import os
import random

import numpy as np
import torch
from networks.networks import *
from networks.blocks import *
from networks.skeleton_networks import *

import torch.optim as optim
from torch.nn.utils import clip_grad_norm_
import tensorflow as tf
from collections import OrderedDict
from os.path import join as pjoin
import codecs as cs
from utils.utils import *
import time

class Logger:
    """Small wrapper around a TensorFlow summary file writer for scalar logging."""

    def __init__(self, log_dir):
        """Create the summary writer that emits event files into ``log_dir``."""
        self.writer = tf.summary.create_file_writer(log_dir)

    def scalar_summary(self, tag, value, step):
        """Record ``value`` under ``tag`` at ``step`` and flush the writer right away."""
        summary_writer = self.writer
        with summary_writer.as_default():
            tf.summary.scalar(tag, value, step=step)
            summary_writer.flush()

class BaseTrainer:
    """Utility base class shared by the trainers in this file.

    It bundles small fan-out helpers that apply one call to every optimizer or
    network in a list, a minibatch pair swap, and two ways of drawing random
    temporal crops from a ``(B, D, L)`` batch.
    """

    @staticmethod
    def zero_grad(opt_list):
        """Call ``zero_grad()`` on every object in ``opt_list``."""
        for opt in opt_list:
            opt.zero_grad()

    @staticmethod
    def clip_norm(network_list):
        """Clip the gradient norm of each network's parameters to 0.5."""
        for network in network_list:
            clip_grad_norm_(network.parameters(), 0.5)

    @staticmethod
    def step(opt_list):
        """Call ``step()`` on every optimizer in ``opt_list``."""
        for opt in opt_list:
            opt.step()

    @staticmethod
    def to(net_opt_list, device):
        """Call ``.to(device)`` on every object in ``net_opt_list``."""
        for net_opt in net_opt_list:
            net_opt.to(device)

    @staticmethod
    def net_train(network_list):
        """Switch every network in ``network_list`` to training mode."""
        for network in network_list:
            network.train()

    @staticmethod
    def net_eval(network_list):
        """Switch every network in ``network_list`` to evaluation mode."""
        for network in network_list:
            network.eval()

    @staticmethod
    def swap(x):
        """Swap the two members of every consecutive pair along the batch axis.

        Rows ``(0, 1, 2, 3, ...)`` become ``(1, 0, 3, 2, ...)``. The batch
        size must be even.
        """
        shape = x.shape
        assert shape[0] % 2 == 0, "Minibatch size must be a multiple of 2"
        paired = x.view(shape[0] // 2, 2, *shape[1:])
        return torch.flip(paired, [1]).view(*shape)

    @staticmethod
    def grid_sample_1d(x, target_size, scale_range, num_crops=1):
        """Resample randomly scaled and shifted windows from each sequence.

        Args:
            x: tensor of shape ``(B, D, L)``.
            target_size: number of samples in every returned crop.
            scale_range: ``(min, max)`` bounds of the random window scale,
                expressed as a fraction of the normalised ``[-1, 1]`` axis.
            num_crops: number of crops drawn per batch element.

        Returns:
            Tensor of shape ``(B, num_crops, D, target_size)``.
        """
        total = x.size(0) * num_crops
        device = x.device

        # Normalised sampling positions along the sequence: (total, target_size, 1, 1)
        positions = torch.linspace(-1.0, 1.0, target_size, device=device)
        positions = positions.view(1, -1, 1, 1).expand(total, -1, -1, -1)
        # grid_sample wants (x, y) pairs; the x slot is pinned to -1 because
        # the "image" built below has width 1: (total, target_size, 1, 2)
        unit_grid = torch.cat([torch.ones_like(positions) * -1, positions], dim=3)

        # (B, D, L) -> (total, D, L, 1): the sequence axis plays the image height
        x = x.unsqueeze(1).unsqueeze(-1).expand(-1, num_crops, -1, -1, -1).flatten(0, 1)

        low, high = scale_range[0], scale_range[1]
        scale = torch.rand(total, 1, 1, 1, device=device) * (high - low) + low
        # For scale <= 1 the offset keeps the scaled window inside [-1, 1]
        offset = (torch.rand(total, 1, 1, 1, device=device) * 2 - 1) * (1 - scale)
        sampling_grid = unit_grid * scale + offset
        # Scaling/offsetting also moved the placeholder x coordinate; reset it
        sampling_grid[:, :, :, 0] = -1

        # (total, D, target_size, 1)
        crop = torch.nn.functional.grid_sample(
            x, sampling_grid, align_corners=False, padding_mode="border")
        return crop.view(total // num_crops, num_crops, crop.size(1), crop.size(2))

    def get_random_crops(self, data, grid_sample=True):
        """Draw ``opt.num_crops`` random temporal crops per batch element.

        Args:
            data: tensor of shape ``(B, D, L)``.
            grid_sample: when true, delegate to ``grid_sample_1d`` (random
                scale + interpolation); otherwise slice fixed-length windows
                of ``opt.patch_size`` frames at random start positions.

        Returns:
            Tensor of shape ``(B, num_crops, D, patch_size)``.

        Raises:
            ValueError: on the slicing path, if the sequence is shorter than
                ``opt.patch_size``.
        """
        if grid_sample:
            return self.grid_sample_1d(data, self.opt.patch_size,
                                       (self.opt.patch_min_scale, self.opt.patch_max_scale),
                                       self.opt.num_crops)

        batch, dim, length = data.shape
        patch = self.opt.patch_size
        last_start = length - patch
        if last_start < 0:
            raise ValueError(
                "sequence length %d is shorter than patch_size %d" % (length, patch))

        repeated = data.unsqueeze(1).expand(-1, self.opt.num_crops, -1, -1).flatten(0, 1)
        # randint's upper bound is exclusive, so +1 makes every valid window
        # (including the one ending on the final frame) reachable
        starts = np.random.randint(0, last_start + 1, repeated.shape[0])
        windows = [repeated[row, :, int(begin):int(begin) + patch]
                   for row, begin in enumerate(starts)]
        return torch.stack(windows, dim=0).view(batch, -1, dim, patch)


class Trainer(BaseTrainer):
    def __init__(self, opt, encoder, generator, d_patch=None, dis=None):
        self.opt = opt
        self.E = encoder
        self.G = generator
        # if d_patch is not None:
        self.Dpatch = d_patch
        # if dis is not None:
        self.D = dis
        self.discriminator_iter = 0

        if self.opt.is_train:
            self.logger = Logger(self.opt.log_dir)
            self.l1_criterion = torch.nn.L1Loss()
            self.gan_criterion = gan_loss

    def save(self, file_name, ep, total_it):
        state = {
            "encoder": self.E.state_dict(),
            "generator": self.G.state_dict(),

            "opt_encoder": self.opt_E.state_dict(),
            "opt_generator": self.opt_G.state_dict(),

            "ep": ep,
            "total_it": total_it,
        }
        if self.D is not None:
            state["dis"] = self.D.state_dict()
            state["opt_dis"] = self.opt_D.state_dict()

        if self.Dpatch is not None:
            state["patch_dis"] = self.Dpatch.state_dict()
            state["opt_patch_dis"] = self.opt_Dpatch.state_dict()
        torch.save(state, file_name)

    def resume(self, model_dir):
        checkpoint = torch.load(model_dir, map_location=self.opt.device)
        self.E.load_state_dict(checkpoint["encoder"])
        self.G.load_state_dict(checkpoint["generator"])

        if self.opt.is_train:
            self.opt_E.load_state_dict(checkpoint["opt_encoder"])
            self.opt_G.load_state_dict(checkpoint["opt_generator"])

            if self.D is not None:
                self.D.load_state_dict(checkpoint["dis"])
                self.opt_D.load_state_dict(checkpoint["opt_dis"])
            if self.Dpatch is not None:
                self.Dpatch.load_state_dict(checkpoint["patch_dis"])
                self.opt_Dpatch.load_state_dict(checkpoint["opt_patch_dis"])
        return checkpoint["ep"], checkpoint["total_it"]

    def forward(self, data):
        # (B, seq_len, pose_dim) -> (B, pose_dim, seq_len)
        real_data = data.permute(0, 2, 1).to(self.opt.device).float().detach()
        B = real_data.size(0)
        sp, gl = self.E(real_data[:, :-4])

        rec = self.G(sp, gl)
        sp_mix = self.swap(sp)

        mix = self.G(sp_mix, gl)

        if self.opt.do_cycle:
            self.msp, self.mgl = self.E(mix[:, :-4])

        # print(real_data.shape)
        # print(rec.shape)
        # print(mix.shape)
        # print(sp_mix.shape)
        self.real_data = real_data
        self.rec = rec
        self.mix = mix
        self.sp_mix = sp_mix
        self.sp, self.gl = sp, gl

    def backward_G(self):
        B = self.real_data.size(0)
        loss_dict = OrderedDict()
        self.loss_G_rec = self.l1_criterion(self.rec, self.real_data)
        self.loss_G = self.loss_G_rec
        loss_dict["loss_G_rec"] = self.loss_G_rec.item()

        if self.opt.do_cycle:
            self.loss_G_sp_rec = self.l1_criterion(self.msp, self.sp_mix.detach())
            self.loss_G_gl_rec = self.l1_criterion(self.mgl, self.gl.detach())
            self.loss_G += self.loss_G_sp_rec * self.opt.lambda_sp_rec
            self.loss_G += self.loss_G_gl_rec * self.opt.lambda_gl_rec
            loss_dict["loss_G_sp_rec"] = self.loss_G_sp_rec.item()
            loss_dict["loss_G_gl_rec"] = self.loss_G_gl_rec.item()


        if self.opt.do_gan:
            self.loss_G_gan_rec = self.gan_criterion(self.D(self.rec),
                                                     should_be_classified_as_real=True)
            self.loss_G_gan_mix = self.gan_criterion(self.D(self.mix),
                                                     should_be_classified_as_real=True)
            self.loss_G += self.loss_G_gan_rec * self.opt.lambda_gan * 0.5
            self.loss_G += self.loss_G_gan_mix * self.opt.lambda_gan
            loss_dict["loss_G_gan_rec"] = self.loss_G_gan_rec.item()
            loss_dict["loss_G_gan_mix"] = self.loss_G_gan_mix.item()

        if self.opt.do_patch_gan:
            if self.real_feat is not None:
                real_feat = self.real_feat.detach()
            else:
                real_feat = self.Dpatch.extract_features(
                    self.get_random_crops(self.real_data,
                                          grid_sample=self.opt.do_grid_sample),
                    aggregate=self.opt.do_patch_agg).detach()
            mix_feat = self.Dpatch.extract_features(
                self.get_random_crops(
                    self.mix, grid_sample=self.opt.do_grid_sample
                ))
            self.loss_G_gan_patch = self.gan_criterion(
                self.Dpatch.discriminate_features(real_feat, mix_feat),
                should_be_classified_as_real=True)
            self.real_feat = None
            self.loss_G += self.loss_G_gan_patch * self.opt.lambda_patch_gan
            loss_dict["loss_G_gan_patch"] = self.loss_G_gan_patch.item()
        loss_dict["loss_G"] = self.loss_G.item()

        return loss_dict

    def backward_D(self):
        # B = self.real_data.shape[0]
        pred_real = self.D(self.real_data.detach())
        pred_rec = self.D(self.rec.detach())
        pred_mix = self.D(self.mix.detach())
        loss_dict = OrderedDict()

        loss_D_real = self.gan_criterion(pred_real, should_be_classified_as_real=True)
        loss_D_rec = self.gan_criterion(pred_rec, should_be_classified_as_real=False)
        loss_D_mix = self.gan_criterion(pred_mix, should_be_classified_as_real=False)
        loss_dict["loss_D_real"] = loss_D_real.item()
        loss_dict["loss_D_rec"] = loss_D_rec.item()
        loss_dict["loss_D_mix"] = loss_D_mix.item()
        self.loss_D = loss_D_real * self.opt.lambda_gan + \
                      (loss_D_rec + loss_D_mix) * self.opt.lambda_gan * 0.5
        loss_dict["loss_D"] = self.loss_D.item()
        return loss_dict

    def backward_Dpatch(self):
        loss_dict = OrderedDict()
        self.real_crops = self.get_random_crops(
            self.real_data, grid_sample=self.opt.do_grid_sample).detach()
        self.target_crops = self.get_random_crops(
            self.real_data, grid_sample=self.opt.do_grid_sample).detach()
        self.real_feat = self.Dpatch.extract_features(
            self.real_crops,
            aggregate=self.opt.do_patch_agg
        )

        target_feat = self.Dpatch.extract_features(
            self.target_crops
        )

        mix_feat = self.Dpatch.extract_features(
            self.get_random_crops(self.mix, grid_sample=self.opt.do_grid_sample).detach()
        )

        loss_Dpatch_real = self.gan_criterion(
            self.Dpatch.discriminate_features(self.real_feat, target_feat),
            should_be_classified_as_real=True)
        loss_Dpatch_mix = self.gan_criterion(
            self.Dpatch.discriminate_features(self.real_feat, mix_feat),
            should_be_classified_as_real=False
        )
        self.loss_Dpatch = (loss_Dpatch_real + loss_Dpatch_mix) * self.opt.lambda_patch_gan
        loss_dict["loss_Dpatch_real"] = loss_Dpatch_real.item()
        loss_dict["loss_Dpatch_mix"] = loss_Dpatch_mix.item()
        loss_dict["loss_Dpatch"] = self.loss_Dpatch.item()
        return loss_dict

    def backward_R1(self):
        loss_dict = OrderedDict()
        if self.opt.lambda_R1 > 0.0:
            self.real_data.requires_grad_()
            pred_real = self.D(self.real_data).sum()
            grad_real, = torch.autograd.grad(
                outputs=pred_real,
                inputs=self.real_data,
                create_graph=True,
                retain_graph=True,
                only_inputs=True
            )
            grad_real2 = grad_real.pow(2)
            # dims = list(range(1, grad_real2.ndim))
            grad_penalty = grad_real2.sum() / self.real_data.size(0)
            loss_dict["grad_D_penalty"] = grad_penalty.item()
        else:
            grad_penalty = 0.0

        if self.opt.lambda_patch_R1 > 0.0:
            self.real_crops.requires_grad_()
            self.target_crops.requires_grad_()

            real_feat = self.Dpatch.extract_features(self.real_crops, aggregate=self.opt.do_patch_agg)
            target_feat = self.Dpatch.extract_features(self.target_crops)

            pred_real_patch = self.Dpatch.discriminate_features(real_feat, target_feat).sum()

            grad_real, grad_target = torch.autograd.grad(
                outputs=pred_real_patch,
                inputs=[self.real_crops, self.target_crops],
                create_graph=True,
                retain_graph=True,
                only_inputs=True,
            )
            # dims = list(range(1, grad_real.ndim))
            grad_crop_penalty = grad_real.pow(2).sum() / self.real_crops.shape[0] \
                                + grad_target.pow(2).sum() / self.target_crops.shape[0]
            loss_dict["grad_Dpatch_penalty"] = grad_crop_penalty.item()
        else:
            grad_crop_penalty = 0
        self.grad_penalty = grad_penalty * self.opt.lambda_R1 * 0.5 + \
                            grad_crop_penalty * self.opt.lambda_patch_R1 * 0.5 * 0.5
        loss_dict["grad_penalty"] = self.grad_penalty.item()
        return loss_dict

    def update(self):
        loss_dict = OrderedDict()

        if self.opt.do_gan:
            self.zero_grad([self.opt_D])
            losses = self.backward_D()
            self.loss_D.backward(retain_graph=True)
            self.clip_norm([self.D])
            self.step([self.opt_D])
            loss_dict.update(losses)

        if self.opt.do_patch_gan:
            self.zero_grad([self.opt_Dpatch])
            losses = self.backward_Dpatch()
            self.loss_Dpatch.backward(retain_graph=True)
            self.clip_norm([self.Dpatch])
            self.step([self.opt_Dpatch])
            loss_dict.update(losses)

            self.discriminator_iter += 1

        self.zero_grad([self.opt_E, self.opt_G])
        losses = self.backward_G()
        self.loss_G.backward()
        self.clip_norm([self.E, self.G])
        self.step([self.opt_E, self.opt_G])
        loss_dict.update(losses)


        if self.opt.do_gan or self.opt.do_patch_gan:
            if self.discriminator_iter % self.opt.R1_every_iter == 0:
                self.zero_grad([self.D, self.Dpatch])
                losses = self.backward_R1()
                self.grad_penalty.backward()
                self.step([self.opt_D, self.opt_Dpatch])
                loss_dict.update(losses)
        return loss_dict

    def train(self, train_dataloader, val_dataloader, plot_eval):
        """Run the training loop with per-epoch validation and checkpointing.

        Args:
            train_dataloader: iterable of training batches fed to ``forward``;
                must support ``len()``.
            val_dataloader: iterable of validation batches; must support ``len()``.
            plot_eval: callable invoked as ``plot_eval(data, save_dir)`` every
                ``opt.eval_every_e`` epochs with a numpy array of samples.
        """
        # Only the networks enabled by the options are moved / mode-switched
        net_list = [self.G, self.E]
        if self.opt.do_gan:
            net_list.append(self.D)
        if self.opt.do_patch_gan:
            net_list.append(self.Dpatch)
        self.to(net_list, self.opt.device)

        self.opt_E = optim.Adam(self.E.parameters(), lr=self.opt.lr, betas=(0, 0.99))
        self.opt_G = optim.Adam(self.G.parameters(), lr=self.opt.lr, betas=(0, 0.99))

        # Discriminator optimizers use lr * c and beta2 ** c, where c depends
        # on how often the R1 step runs.
        # NOTE(review): presumably compensates for the extra R1 step - confirm.
        c = self.opt.R1_every_iter / (1 + self.opt.R1_every_iter)
        if self.opt.do_gan:
            self.opt_D = optim.Adam(self.D.parameters(), lr=self.opt.lr * c, betas=(0, 0.99**c))

        if self.opt.do_patch_gan:
            self.opt_Dpatch = optim.Adam(self.Dpatch.parameters(), lr=self.opt.lr*c, betas=(0, 0.99**c))

        epoch = 0
        it = 0
        if self.opt.is_continue:
            # Optimizers must already exist here: resume() loads their state
            model_dir = pjoin(self.opt.model_dir, "latest.tar")
            epoch, it = self.resume(model_dir)

        start_time = time.time()
        total_iters = self.opt.max_epoch * len(train_dataloader)
        print("Iters Per Epoch, Training: %04d, Validation: %03d"%(len(train_dataloader), len(val_dataloader)))
        min_val_loss = np.inf
        # NOTE(review): min_val_epoch is assigned but never read in this method
        min_val_epoch = epoch
        # Running sums of each loss since the last log flush
        logs = OrderedDict()

        while epoch < self.opt.max_epoch:
            self.net_train(net_list)
            for i, batch_data in enumerate(train_dataloader):
                self.forward(batch_data)
                loss_dict = self.update()

                for k, v in loss_dict.items():
                    if k not in logs:
                        logs[k] = v
                    else:
                        logs[k] += v
                it += 1
                if it % self.opt.log_every == 0:
                    # Every sum is divided by log_every, including terms that
                    # were not produced on every iteration
                    mean_loss = OrderedDict()
                    for tag, value in logs.items():
                        self.logger.scalar_summary(tag, value/self.opt.log_every, it)
                        mean_loss[tag] = value / self.opt.log_every
                    logs = OrderedDict()
                    print_current_loss(start_time, it, total_iters, mean_loss, epoch, i)

                if it % self.opt.save_latest == 0:
                    self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)

            # Saved before the epoch counter is incremented
            self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)
            epoch += 1
            if epoch % self.opt.save_every_e == 0:
                self.save(pjoin(self.opt.model_dir, "E%04d.tar"%(epoch)), epoch, total_it=it)

            print("Validation time:")
            # NOTE(review): val_loss stays None when val_dataloader yields no
            # batches, and the loop over val_loss.items() below then fails
            val_loss = None
            with torch.no_grad():
                self.net_eval(net_list)
                for i, batch_data in enumerate(val_dataloader):
                    self.forward(batch_data)
                    # Only the generator loss is evaluated; no optimizer step
                    loss_dict = self.backward_G()
                    if val_loss is None:
                        val_loss = loss_dict
                    else:
                        for k, v in loss_dict.items():
                            val_loss[k] += v
            print_str = "Validation Loss:"
            for k, v in val_loss.items():
                val_loss[k] /= len(val_dataloader)
                # Validation reuses the training tag names but is indexed by
                # epoch rather than by iteration
                self.logger.scalar_summary(k, val_loss[k], epoch)
                print_str += ' %s: %.4f '%(k, val_loss[k])
            print(print_str)

            if val_loss["loss_G"] < min_val_loss:
                min_val_loss = val_loss["loss_G"]
                min_val_epoch = epoch
                self.save(pjoin(self.opt.model_dir, "best.tar"), epoch, it)
                print("Best Validation Model So Far!~")

            if epoch % self.opt.eval_every_e == 0:
                # Samples come from the last validation batch left on self by
                # forward(): three real items from each half of the batch plus
                # three reconstructions and three mixes
                B = self.real_data.size(0)
                data = torch.cat([self.real_data[:3], self.real_data[B//2:B//2+3], self.rec[:3], self.mix[:3]],
                                 dim=0)
                # Back to (N, seq_len, pose_dim)
                data = data.permute(0, 2, 1).detach().cpu().numpy()
                save_dir = pjoin(self.opt.eval_dir, "E%04d"%(epoch))
                os.makedirs(save_dir, exist_ok=True)
                plot_eval(data, save_dir)

class SkeletonTrainer(BaseTrainer):
    def __init__(self, opt, encoder, decoder):
        self.opt = opt
        self.encoder = encoder
        self.decoder = decoder
        # self.discriminator = discriminator

        if self.opt.is_train:
            self.logger = Logger(self.opt.log_dir)
            self.mse_criterion = torch.nn.MSELoss()
            self.l1_criterion = torch.nn.SmoothL1Loss()
            # self.cross_entropy_criterion = torch.nn.CrossEntropyLoss()

    @staticmethod
    def reparametrize(mu, logvar):
        s_var = logvar.mul(0.5).exp_()
        eps = s_var.data.new(s_var.size()).normal_()
        return eps.mul(s_var).add_(mu)

    @staticmethod
    def ones_like(tensor, val=1.):
        return torch.FloatTensor(tensor.size()).fill_(val).to(tensor.device).requires_grad_(False)

    @staticmethod
    def zeros_like(tensor, val=0.):
        return torch.FloatTensor(tensor.size()).fill_(val).to(tensor.device).requires_grad_(False)

    @staticmethod
    def kl_criterion(mu1, logvar1, mu2, logvar2):
        # KL( N(mu1, sigma2_1) || N(mu_2, sigma2_2))
        # loss = log(sigma2/sigma1) + (sigma1^2 + (mu1 - mu2)^2)/(2*sigma2^2) - 1/2
        sigma1 = logvar1.mul(0.5).exp()
        sigma2 = logvar2.mul(0.5).exp()
        kld = torch.log(sigma2 / sigma1) + (torch.exp(logvar1) + (mu1 - mu2) ** 2) / (
                2 * torch.exp(logvar2)) - 1 / 2
        return kld.sum() / np.prod(mu1.shape)

    @staticmethod
    def kl_criterion_unit(mu, logvar):
        # KL( N(mu1, sigma2_1) || N(mu_2, sigma2_2))
        # loss = log(sigma2/sigma1) + (sigma1^2 + (mu1 - mu2)^2)/(2*sigma2^2) - 1/2
        kld = ((torch.exp(logvar) + mu ** 2) - logvar - 1) / 2
        return kld.sum() / np.prod(mu.shape)

    def forward(self, batch_data):
        """Encode two motions and decode direct, swapped and cycle reconstructions.

        Naming used below: index 1 and 2 are the two motions of the batch,
        index 3 is the pair-swapped version of index 2 (see ``swap``), index 4
        is the re-encoding of ``RM3``. ``sp*`` is the first encoder output,
        ``gl_mu*`` / ``gl_logvar*`` the second and third.

        Everything ``backward`` and ``train`` need is stored on ``self``.
        """
        if self.opt.dataset_name == "cmu":
            M1, M2, A1, S1, SID1 = batch_data
        else:
            M1, M2, MS, _, A1, S1, SID1, _, _ = batch_data
        A2, S2 = A1, S1

        # NOTE(review): MS is only unpacked in the non-"cmu" branch, so
        # use_style together with dataset_name == "cmu" would fail here.
        if self.opt.use_style:
            M2 = MS.clone()
        # Zero channels 1:3 of the last axis; this edits the batch tensors in
        # place (generate()/generatev2() clone first, this method does not).
        M1[..., 1:3] *= 0
        M2[..., 1:3] *= 0

        # (B, seq_len, pose_dim) -> (B, pose_dim, seq_len)
        M1 = M1.permute(0, 2, 1).to(self.opt.device).float().detach()
        M2 = M2.permute(0, 2, 1).to(self.opt.device).float().detach()
        M3 = self.swap(M2)
        SID3 = self.swap(SID1)

        if self.opt.use_action:
            A1 = A1.to(self.opt.device).float().detach()
            A2 = A2.to(self.opt.device).float().detach()
            A3 = self.swap(A2)
        else:
            A1, A2, A3 = None, None, None

        if self.opt.use_style:
            S1 = S1.to(self.opt.device).float().detach()
            S2 = S2.to(self.opt.device).float().detach()
            S3 = self.swap(S2)
        else:
            S1, S2, S3 = None, None, None

        # The encoder sees every channel except the last four
        sp1, gl_mu1, gl_logvar1 = self.encoder(M1[:, :-4], A1, S1)
        sp2, gl_mu2, gl_logvar2 = self.encoder(M2[:, :-4], A2, S2)
        # Codes for index 3 are the pair-swapped codes of index 2 (no extra encoder pass)
        sp3, gl_mu3, gl_logvar3 = self.swap(sp2), self.swap(gl_mu2), self.swap(gl_logvar2)

        if self.opt.use_vae:
            # Only the gl code is sampled; the sp code is used as-is
            z_sp1 = sp1
            z_gl1 = self.reparametrize(gl_mu1, gl_logvar1)

            z_sp2 = sp2
            z_gl2 = self.reparametrize(gl_mu2, gl_logvar2)

            # Detached: no gradient flows into sp2 through RM3 on this path
            z_sp3 = z_sp2.detach()
            z_gl3 = self.reparametrize(gl_mu3, gl_logvar3)
        else:
            z_sp1 = sp1
            z_gl1 = gl_mu1

            z_sp2 = sp2
            # NOTE(review): this uses gl_mu1, whereas the VAE branch above
            # samples from gl_mu2 - confirm whether the asymmetry is intended.
            z_gl2 = gl_mu1

            z_sp3 = sp2
            z_gl3 = gl_mu3

        RM1 = self.decoder(z_sp1, z_gl1, A1, S1)
        RM2 = self.decoder(z_sp2, z_gl2, A2, S2)
        # sp code of motion 2 combined with the swapped gl code
        RM3 = self.decoder(z_sp3, z_gl3, A2, S3)

        # Re-encode the swapped-code output for the cycle terms
        sp4, gl_mu4, gl_logvar4 = self.encoder(RM3[:, :-4], A2, S3)

        if self.opt.use_vae:
            z_sp4 = sp4
            # Detached distribution of motion 2
            z_gl4 = self.reparametrize(gl_mu2.detach(), gl_logvar2.detach())
            z_sp5 = sp3.detach()
            z_gl5 = self.reparametrize(gl_mu4, gl_logvar4)
        else:
            z_sp4 = sp4
            z_gl4 = gl_mu2

            z_sp5 = sp3
            z_gl5 = gl_mu4

        # backward() compares RRM2 with M2 and RRM3 with M3
        RRM2 = self.decoder(z_sp4, z_gl4, A2, S2)
        RRM3 = self.decoder(z_sp5, z_gl5, A3, S3)
        self.RM3 = RM3
        self.SID1 = SID1
        self.SID3 = SID3
        self.M1, self.M2, self.M3 = M1, M2, M3
        self.RM1, self.RM2, self.RRM2, self.RRM3 = RM1, RM2, RRM2, RRM3
        self.gl_mu1, self.gl_mu2, self.gl_mu3, self.gl_mu4 = gl_mu1, gl_mu2, gl_mu3, gl_mu4
        self.sp2, self.sp4 = sp2, sp4
        self.gl_logvar1, self.gl_logvar2, self.gl_logvar3, self.gl_logvar4 = gl_logvar1, gl_logvar2, gl_logvar3, gl_logvar4

    def generate(self, M1, M2, A1, A2, S1, S2, sampling, label_switch):
        # M1, _, A1, S1, SID1 = batch_data
        # print("M1", M1)
        # print("M2", M2)
        # (B, seq_len, pose_dim) -> (B, pose_dim, seq_len)
        # OM1 = M1.clone()
        # OM2 = self.swap(OM1)
        M1 = M1.clone()
        M2 = M2.clone()
        M1[..., 1:3] *= 0
        M2[..., 1:3] *= 0
        M1 = M1.permute(0, 2, 1).to(self.opt.device).float().detach()
        M2 = M2.permute(0, 2, 1).to(self.opt.device).float().detach()

        if self.opt.use_style:
            S1 = S1.to(self.opt.device).float().detach()
            S2 = S2.to(self.opt.device).float().detach()
        else:
            S1, S2 = None, None

        if self.opt.use_action:
            A1 = A1.to(self.opt.device).float().detach()
            A2 = A2.to(self.opt.device).float().detach()
        else:
            A1, A2 = None, None
        # print(M1[:, :-4].shape, M2[:, :-4].shape, S1.shape, S2.shape)
        sp1, gl_mu1, gl_logvar1 = self.encoder(M1[:, :-4], A1, S1)
        sp2, gl_mu2, gl_logvar2 = self.encoder(M2[:, :-4], A2, S2)

        # z_sp = self.reparametrize(sp_mu1, sp_logvar1)
        z_sp = sp1
        if sampling:
            # Sample from normal distribution, novel style generation
            z_gl = self.reparametrize(self.zeros_like(gl_mu1), self.zeros_like(gl_logvar1))
        else:
            # Sample from M2 distribution, motion style transfer
            z_gl = self.reparametrize(gl_mu2, gl_logvar2)

        if label_switch:
            S = S2
        else:
            S = S1
        TM = self.decoder(z_sp, z_gl, A1, S)

        return TM.permute(0, 2, 1)

    def generatev2(self, M1, M2, S2, sampling):
        # M1, _, A1, S1, SID1 = batch_data
        # print("M1", M1)
        # print("M2", M2)
        # (B, seq_len, pose_dim) -> (B, pose_dim, seq_len)
        # OM1 = M1.clone()
        # OM2 = self.swap(OM1)
        M1 = M1.clone()
        M2 = M2.clone()
        M1[..., 1:3] *= 0
        M2[..., 1:3] *= 0
        M1 = M1.permute(0, 2, 1).to(self.opt.device).float().detach()
        M2 = M2.permute(0, 2, 1).to(self.opt.device).float().detach()

        if self.opt.use_style:
            S2 = S2.to(self.opt.device).float().detach()
        else:
            S2 = None

        # print(M1[:, :-4].shape, M2[:, :-4].shape, S1.shape, S2.shape)
        sp1 = self.encoder.extract_content_feature(M1[:, :-4], None)
        gl_mu2, gl_logvar2 = self.encoder.extract_style_feature(M2[:, :-4], S2)

        # z_sp = self.reparametrize(sp_mu1, sp_logvar1)
        z_sp = sp1
        if sampling:
            # Sample from normal distribution, novel style generation
            z_gl = self.reparametrize(self.zeros_like(gl_mu2), self.zeros_like(gl_logvar2))
        else:
            # Sample from M2 distribution, motion style transfer
            z_gl = self.reparametrize(gl_mu2, gl_logvar2)

        TM = self.decoder(z_sp, z_gl, None, S2)

        return TM.permute(0, 2, 1)


    def backward(self):
        self.loss_rec_m1 = self.l1_criterion(self.M1, self.RM1)
        self.loss_rec_m2 = self.l1_criterion(self.M2, self.RM2)
        self.loss_rec_rm2 = self.l1_criterion(self.M2, self.RRM2)
        self.loss_rec_rm3 = self.l1_criterion(self.M3, self.RRM3)

        self.loss_rec_m3r = self.l1_criterion(self.M2[:, :3], self.RM3[:, :3])
        self.loss_rec_m3f = self.l1_criterion(self.M2[:, -4:], self.RM3[:, -4:])

        self.loss_rec_lat = self.l1_criterion(self.sp2, self.sp4)

        # print(self.loss_rec_m1, self.loss_rec_m2, self.loss_rec_rm2, self.loss_rec_rm3)

        if self.opt.use_vae:
            self.loss_kld_gl_m1 = self.kl_criterion_unit(self.gl_mu1, self.gl_logvar1)
            self.loss_kld_gl_m2 = self.kl_criterion_unit(self.gl_mu2, self.gl_logvar2)
            self.loss_kld_gl_m4 = self.kl_criterion_unit(self.gl_mu4, self.gl_logvar4)
            self.loss_kld_gl_m12 = self.kl_criterion(self.gl_mu1, self.gl_logvar1, self.gl_mu2, self.gl_logvar2)
            self.loss_kld_gl_m34 = self.kl_criterion(self.gl_mu3, self.gl_logvar3, self.gl_mu4, self.gl_logvar4)
        else:
            self.loss_kld_gl_m1 = torch.zeros(1, device=self.gl_mu1.device)
            self.loss_kld_gl_m2 = torch.zeros(1, device=self.gl_mu1.device)
            self.loss_kld_gl_m4 = torch.zeros(1, device=self.gl_mu1.device)
            self.loss_kld_gl_m12 = torch.zeros(1, device=self.gl_mu1.device)
        self.loss = (self.loss_rec_m1 + self.loss_rec_m2) * self.opt.lambda_rec + \
                    (self.loss_rec_rm2 + self.loss_rec_rm3) * self.opt.lambda_rec_c + \
                    (self.loss_kld_gl_m1 + self.loss_kld_gl_m2 + self.loss_kld_gl_m4) * self.opt.lambda_kld_gl + \
                    (self.loss_kld_gl_m12) * self.opt.lambda_kld_gl12
        # self.loss = (self.loss_rec_m1 + self.loss_rec_m2) * self.opt.lambda_rec + \
        #             (self.loss_kld_sp_m1 + self.loss_kld_sp_m2) * self.opt.lambda_kld_sp + \
        #             (self.loss_kld_gl_m1 + self.loss_kld_gl_m2) * self.opt.lambda_kld_gl
        loss_logs = OrderedDict({})
        loss_logs["loss"] = self.loss.item()
        loss_logs["loss_rec_m1"] = self.loss_rec_m1.item()
        loss_logs["loss_rec_m2"] = self.loss_rec_m2.item()
        loss_logs["loss_rec_rm2"] = self.loss_rec_rm2.item()
        loss_logs["loss_rec_rm3"] = self.loss_rec_rm3.item()
        loss_logs["loss_rec_lat"] = self.loss_rec_lat.item()

        loss_logs["loss_kld_gl_m1"] = self.loss_kld_gl_m1.item()
        loss_logs["loss_kld_gl_m2"] = self.loss_kld_gl_m2.item()
        loss_logs["loss_kld_gl_m4"] = self.loss_kld_gl_m4.item()
        loss_logs["loss_kld_gl_m12"] = self.loss_kld_gl_m12.item()
        loss_logs["loss_kld_gl_m34"] = self.loss_kld_gl_m34.item()


        return loss_logs

    def update(self):
        self.zero_grad([self.opt_encoder, self.opt_decoder])
        loss_logs = self.backward()
        self.loss.backward()
        self.clip_norm([self.encoder, self.decoder])
        self.step([self.opt_encoder, self.opt_decoder])
        return loss_logs

    def save(self, file_name, ep, total_it):
        state = {
            "encoder": self.encoder.state_dict(),
            "decoder": self.decoder.state_dict(),

            "opt_encoder": self.opt_encoder.state_dict(),
            "opt_decoder": self.opt_decoder.state_dict(),

            "ep": ep,
            "total_it": total_it,
        }

        torch.save(state, file_name)

    def resume(self, model_dir):
        # print(model_dir)
        checkpoint = torch.load(model_dir, map_location=self.opt.device)
        self.encoder.load_state_dict(checkpoint["encoder"])
        self.decoder.load_state_dict(checkpoint["decoder"])

        if self.opt.is_train:
            self.opt_encoder.load_state_dict(checkpoint["opt_encoder"])
            self.opt_decoder.load_state_dict(checkpoint["opt_decoder"])
        print("Loading the model from epoch %04d"%checkpoint["ep"])
        return checkpoint["ep"], checkpoint["total_it"]

    def train(self, train_dataloader, val_dataloader, plot_eval):
        """Run the training loop with per-epoch validation and checkpointing.

        Args:
            train_dataloader: iterable of training batches fed to ``forward``;
                must support ``len()``.
            val_dataloader: iterable of validation batches; must support ``len()``.
            plot_eval: callable invoked as ``plot_eval(data, save_dir, styles)``
                every ``opt.eval_every_e`` epochs with numpy arrays.
        """
        net_list = [self.encoder, self.decoder]
        self.to(net_list, self.opt.device)

        self.opt_encoder = optim.Adam(self.encoder.parameters(), lr=self.opt.lr)
        self.opt_decoder = optim.Adam(self.decoder.parameters(), lr=self.opt.lr)

        epoch = 0
        it = 0
        if self.opt.is_continue:
            # Optimizers must already exist here: resume() loads their state
            model_dir = pjoin(self.opt.model_dir, "latest.tar")
            epoch, it = self.resume(model_dir)
            print("Loading model from Epoch %d" % (epoch))

        start_time = time.time()
        total_iters = self.opt.max_epoch * len(train_dataloader)
        print("Iters Per Epoch, Training: %04d, Validation: %03d" % (len(train_dataloader), len(val_dataloader)))
        min_val_loss = np.inf
        # Running sums of each loss since the last log flush
        logs = OrderedDict()

        while epoch < self.opt.max_epoch:
            self.net_train(net_list)
            for i, batch_data in enumerate(train_dataloader):
                self.forward(batch_data)
                loss_dict = self.update()

                for k, v in loss_dict.items():
                    if k not in logs:
                        logs[k] = v
                    else:
                        logs[k] += v
                it += 1
                if it % self.opt.log_every == 0:
                    mean_loss = OrderedDict()

                    for tag, value in logs.items():
                        self.logger.scalar_summary(tag, value / self.opt.log_every, it)
                        mean_loss[tag] = value / self.opt.log_every
                    logs = OrderedDict()
                    print_current_loss(start_time, it, total_iters, mean_loss, epoch, i)

                if it % self.opt.save_latest == 0:
                    self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)

            # Saved before the epoch counter is incremented
            self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)
            epoch += 1

            print("Validation time:")
            # NOTE(review): val_loss stays None when val_dataloader yields no
            # batches, and the loop over val_loss.items() below then fails
            val_loss = None
            with torch.no_grad():
                self.net_eval(net_list)
                for i, batch_data in enumerate(val_dataloader):
                    self.forward(batch_data)
                    # Losses only; no optimizer step during validation
                    loss_dict = self.backward()
                    if val_loss is None:
                        val_loss = loss_dict
                    else:
                        for k, v in loss_dict.items():
                            val_loss[k] += v

            print_str = "Validation Loss:"
            for k, v in val_loss.items():
                val_loss[k] /= len(val_dataloader)
                print_str += ' %s: %.4f ' % (k, val_loss[k])
            # Only the total validation loss is written to the summary log
            self.logger.scalar_summary("val_loss", val_loss["loss"], epoch)
            print(print_str)

            if val_loss["loss"] < min_val_loss:
                min_val_loss = val_loss["loss"]
                # NOTE(review): min_val_epoch is assigned but never read in this method
                min_val_epoch = epoch
                self.save(pjoin(self.opt.model_dir, "best.tar"), epoch, it)
                print("Best Validation Model So Far!~")

            if epoch % self.opt.eval_every_e == 0:
                # Samples come from the last validation batch left on self by
                # forward(); the [:6:2] slices take items 0, 2 and 4.
                # NOTE(review): B is not used after this assignment.
                B = self.M1.size(0)
                data = torch.cat([self.M2[:6:2], self.RM2[:6:2], self.M3[:6:2], self.RM3[:6:2]],
                                 dim=0)
                styles = torch.cat([self.SID1[:6:2], self.SID1[:6:2], self.SID3[:6:2], self.SID3[:6:2]],
                                   dim=0).detach().cpu().numpy()
                # Back to (N, seq_len, pose_dim)
                data = data.permute(0, 2, 1).detach().cpu().numpy()
                save_dir = pjoin(self.opt.eval_dir, "E%04d" % (epoch))
                os.makedirs(save_dir, exist_ok=True)
                plot_eval(data, save_dir, styles)


class VAEGANTrainer(BaseTrainer):
    def __init__(self, opt, encoder, decoder, discriminator):
        self.opt = opt
        self.encoder = encoder
        self.decoder = decoder
        self.discriminator = discriminator

        if self.opt.is_train:
            self.logger = Logger(self.opt.log_dir)
            self.mse_criterion = torch.nn.MSELoss()
            self.l1_criterion = torch.nn.SmoothL1Loss()
            self.cls_criterion = torch.nn.CrossEntropyLoss()

    @staticmethod
    def reparametrize(mu, logvar):
        s_var = logvar.mul(0.5).exp_()
        eps = s_var.data.new(s_var.size()).normal_()
        return eps.mul(s_var).add_(mu)

    @staticmethod
    def ones_like(tensor, val=1.):
        return torch.FloatTensor(tensor.size()).fill_(val).to(tensor.device).requires_grad_(False)

    @staticmethod
    def zeros_like(tensor, val=0.):
        return torch.FloatTensor(tensor.size()).fill_(val).to(tensor.device).requires_grad_(False)

    @staticmethod
    def kl_criterion(mu1, logvar1, mu2, logvar2):
        # KL( N(mu1, sigma2_1) || N(mu_2, sigma2_2))
        # loss = log(sigma2/sigma1) + (sigma1^2 + (mu1 - mu2)^2)/(2*sigma2^2) - 1/2
        sigma1 = logvar1.mul(0.5).exp()
        sigma2 = logvar2.mul(0.5).exp()
        kld = torch.log(sigma2 / sigma1) + (torch.exp(logvar1) + (mu1 - mu2) ** 2) / (
                2 * torch.exp(logvar2)) - 1 / 2
        return kld.sum() / np.prod(mu1.shape)

    @staticmethod
    def kl_criterion_unit(mu, logvar):
        # KL( N(mu1, sigma2_1) || N(mu_2, sigma2_2))
        # loss = log(sigma2/sigma1) + (sigma1^2 + (mu1 - mu2)^2)/(2*sigma2^2) - 1/2
        kld = ((torch.exp(logvar) + mu ** 2) - logvar - 1) / 2
        return kld.sum() / np.prod(mu.shape)

    def forward(self, batch_data):
        M1, _, M2, _, A1, S1, SID1, _, _ = batch_data
        A2, S2 = A1, S1
        # print("M1", M1)
        # print("M2", M2)
        # (B, seq_len, pose_dim) -> (B, pose_dim, seq_len)
        M1[..., 1:3] *= 0
        M2[..., 1:3] *= 0

        M1 = M1.permute(0, 2, 1).to(self.opt.device).float().detach()
        M2 = M2.permute(0, 2, 1).to(self.opt.device).float().detach()
        M3 = self.swap(M2)
        SID3 = self.swap(SID1)


        if self.opt.use_action:
            A1 = A1.to(self.opt.device).float().detach()
            A2 = A2.to(self.opt.device).float().detach()
            A3 = self.swap(A2)
        else:
            A1, A2, A3 = None, None, None

        if self.opt.use_style:
            S1 = S1.to(self.opt.device).float().detach()
            S2 = S2.to(self.opt.device).float().detach()
            S3 = self.swap(S2)
        else:
            S1, S2, S3 = None, None, None

        sp1, gl_mu1, gl_logvar1 = self.encoder(M1[:, :-4], A1, S1)
        sp2, gl_mu2, gl_logvar2 = self.encoder(M2[:, :-4], A2, S2)
        sp3, gl_mu3, gl_logvar3 = self.swap(sp2), self.swap(gl_mu2), self.swap(gl_logvar2)

        if self.opt.use_vae:
            # z_sp1 = self.reparametrize(sp_mu1, sp_logvar1)
            z_sp1 = sp1
            z_gl1 = self.reparametrize(gl_mu1, gl_logvar1)

            # z_sp2 = self.reparametrize(sp_mu2, sp_logvar2)
            z_sp2 = sp2
            # May detach the graph of M1
            z_gl2 = self.reparametrize(gl_mu2, gl_logvar2)

            # May detach the graph of M2
            # z_sp3 = self.reparametrize(sp_mu2.detach(), sp_logvar2.detach())
            z_sp3 = z_sp2.detach()
            z_gl3 = self.reparametrize(gl_mu3, gl_logvar3)
        else:
            z_sp1 = sp1
            z_gl1 = gl_mu1

            z_sp2 = sp2
            z_gl2 = gl_mu1

            z_sp3 = sp2
            z_gl3 = gl_mu3

        RM1 = self.decoder(z_sp1, z_gl1, A1, S1)
        RM2 = self.decoder(z_sp2, z_gl2, A2, S2)
        RM3 = self.decoder(z_sp3, z_gl3, A2, S3)

        z = torch.randn_like(gl_mu1, device=gl_mu1.device)
        NRM3 = self.decoder(z_sp3, z, A2, S3)

        # Should be identical to M2
        # May detach from graph of RM3
        sp4, gl_mu4, gl_logvar4 = self.encoder(RM3[:, :-4], A2, S3)

        spN, gl_muN, gl_logvarN = self.encoder(NRM3[:, :-4], A2, S3)

        if self.opt.use_vae:
            # z_sp4 = self.reparametrize(sp_mu4, sp_logvar4)
            z_sp4 = sp4
            # May detach from graph of M2
            z_gl4 = self.reparametrize(gl_mu2.detach(), gl_logvar2.detach())
            # print("gl_mu2", gl_mu2)
            # print("gl_logvar2", gl_logvar2)
            #  Should be identical to M3
            # May detach from graph of M3
            # z_sp5 = self.reparametrize(sp_mu3.detach(), sp_logvar3.detach())
            z_sp5 = sp3.detach()
            # print("gl_mu4", gl_mu4)
            # print("gl_logvar4", gl_logvar4)
            z_gl5 = self.reparametrize(gl_mu4, gl_logvar4)

            z_glN = self.reparametrize(gl_mu2.detach(), gl_logvar2.detach())
            z_spN = spN
            # print("z_gl5", z_gl5)
        else:
            z_sp4 = sp4
            z_gl4 = gl_mu2.detach()

            z_sp5 = sp3.detach()
            z_gl5 = gl_mu4

            z_spN = spN
            z_glN = gl_mu2.detach()


        RRM2 = self.decoder(z_sp4, z_gl4, A2, S2)
        RRM3 = self.decoder(z_sp5, z_gl5, A3, S3)
        RRMN = self.decoder(z_spN, z_glN, A2, S2)

        self.batch_data = batch_data
        self.RM3, self.NRM3 = RM3, NRM3
        self.SID1 = SID1.to(self.opt.device)
        self.SID3 = SID3.to(self.opt.device)
        self.M1, self.M2, self.M3 = M1, M2, M3
        self.RM1, self.RM2, self.RRM2, self.RRM3, self.RRMN = RM1, RM2, RRM2, RRM3, RRMN
        self.gl_mu1, self.gl_mu2, self.gl_mu3, self.gl_mu4, self.gl_muN = gl_mu1, gl_mu2, gl_mu3, gl_mu4, gl_muN
        self.gl_logvar1, self.gl_logvar2, self.gl_logvar3, self.gl_logvar4, self.gl_logvarN = gl_logvar1, gl_logvar2, gl_logvar3, gl_logvar4, gl_logvarN

    def generate(self, M1, M2, A1, A2, S1, S2, sampling, label_switch):
        # M1, _, A1, S1, SID1 = batch_data
        # print("M1", M1)
        # print("M2", M2)
        # (B, seq_len, pose_dim) -> (B, pose_dim, seq_len)
        # OM1 = M1.clone()
        # OM2 = self.swap(OM1)
        M1 = M1.clone()
        M2 = M2.clone()
        M1[..., 1:3] *= 0
        M2[..., 1:3] *= 0
        M1 = M1.permute(0, 2, 1).to(self.opt.device).float().detach()
        M2 = M2.permute(0, 2, 1).to(self.opt.device).float().detach()

        if self.opt.use_style:
            S1 = S1.to(self.opt.device).float().detach()
            S2 = S2.to(self.opt.device).float().detach()
        else:
            S1, S2= None, None

        if self.opt.use_action:
            A1 = A1.to(self.opt.device).float().detach()
            A2 = A2.to(self.opt.device).float().detach()
        else:
            A1, A2 = None, None
        # print(M1[:, :-4].shape, M2[:, :-4].shape, S1.shape, S2.shape)
        sp1, gl_mu1, gl_logvar1 = self.encoder(M1[:, :-4], A1, S1)
        sp2, gl_mu2, gl_logvar2 = self.encoder(M2[:, :-4], A2, S2)

        # z_sp = self.reparametrize(sp_mu1, sp_logvar1)
        z_sp = sp1
        if sampling:
            # Sample from normal distribution, novel style generation
            z_gl = self.reparametrize(self.zeros_like(gl_mu1), self.zeros_like(gl_logvar1))
        else:
            # Sample from M2 distribution, motion style transfer
            z_gl = self.reparametrize(gl_mu2, gl_logvar2)

        if label_switch:
            S = S2
        else:
            S = S1
        TM = self.decoder(z_sp, z_gl, A1, S)

        return TM.permute(0, 2, 1)

    def backward_EG_VAE(self):
        self.loss_rec_m1 = self.l1_criterion(self.M1, self.RM1)
        self.loss_rec_m2 = self.l1_criterion(self.M2, self.RM2)
        self.loss_rec_rm2 = self.l1_criterion(self.M2, self.RRM2)
        self.loss_rec_rm3 = self.l1_criterion(self.M3, self.RRM3)
        self.loss_rec_rmN = self.l1_criterion(self.M2, self.RRMN)

        # print(self.loss_rec_m1, self.loss_rec_m2, self.loss_rec_rm2, self.loss_rec_rm3)

        if self.opt.use_vae:
            self.loss_kld_gl_m1 = self.kl_criterion_unit(self.gl_mu1, self.gl_logvar1)
            self.loss_kld_gl_m2 = self.kl_criterion_unit(self.gl_mu2, self.gl_logvar2)
            self.loss_kld_gl_m4 = self.kl_criterion_unit(self.gl_mu4, self.gl_logvar4)
            self.loss_kld_gl_mN = self.kl_criterion_unit(self.gl_muN, self.gl_logvarN)
            self.loss_kld_gl_m12 = self.kl_criterion(self.gl_mu1, self.gl_logvar1, self.gl_mu2, self.gl_logvar2)
        else:
            self.loss_kld_gl_m1 = torch.zeros(1, device=self.gl_mu1.device)
            self.loss_kld_gl_m2 = torch.zeros(1, device=self.gl_mu1.device)
            self.loss_kld_gl_m4 = torch.zeros(1, device=self.gl_mu1.device)
            self.loss_kld_gl_m12 = torch.zeros(1, device=self.gl_mu1.device)

        self.loss_EG_vae = (self.loss_rec_m1 + self.loss_rec_m2) * self.opt.lambda_rec + \
                    (self.loss_rec_rm2 + self.loss_rec_rm3) * self.opt.lambda_rec_c + \
                    (self.loss_kld_gl_m1 + self.loss_kld_gl_m2 + self.loss_kld_gl_m4) * self.opt.lambda_kld_gl + \
                    self.loss_kld_gl_m12 * self.opt.lambda_kld_gl12
        # self.loss = (self.loss_rec_m1 + self.loss_rec_m2) * self.opt.lambda_rec + \
        #             (self.loss_kld_sp_m1 + self.loss_kld_sp_m2) * self.opt.lambda_kld_sp + \
        #             (self.loss_kld_gl_m1 + self.loss_kld_gl_m2) * self.opt.lambda_kld_gl
        loss_logs = OrderedDict({})
        loss_logs["loss_EG_vae"] = self.loss_EG_vae.item()
        loss_logs["loss_rec_m1"] = self.loss_rec_m1.item()
        loss_logs["loss_rec_m2"] = self.loss_rec_m2.item()
        loss_logs["loss_rec_rm2"] = self.loss_rec_rm2.item()
        loss_logs["loss_rec_rm3"] = self.loss_rec_rm3.item()
        loss_logs["loss_rec_rmN"] = self.loss_rec_rmN.item()

        loss_logs["loss_kld_gl_m1"] = self.loss_kld_gl_m1.item()
        loss_logs["loss_kld_gl_m2"] = self.loss_kld_gl_m2.item()
        loss_logs["loss_kld_gl_m4"] = self.loss_kld_gl_m4.item()
        loss_logs["loss_kld_gl_mN"] = self.loss_kld_gl_mN.item()

        loss_logs["loss_kld_gl_m12"] = self.loss_kld_gl_m12.item()

        return loss_logs

    def backward_G_ADV(self):
        # if random.random() > 0.5:
        #     _, recn_dis_pred, _, recn_cls_pred = self.discriminator(torch.cat([self.RM1, self.RRM2], dim=0)[:, :-4])
        # else:
        #     _, recn_dis_pred, _, recn_cls_pred = self.discriminator(torch.cat([self.RRMN, self.RRM3], dim=0)[:, :-4])
        _, tran_dis_pred, _, tran_cls_pred = self.discriminator(self.RM3[:, :-4])
        _, gene_dis_pred, _, gene_cls_pred = self.discriminator(self.NRM3[:, :-4])

        real_label = self.zeros_like(tran_dis_pred[:, 0], 0).long()
        # real_label_recn = self.zeros_like(recn_dis_pred[..., 0], 0)

        self.loss_G_dis = self.cls_criterion(tran_dis_pred, real_label)
                          # self.cls_criterion(gene_dis_pred, real_label)
                          # self.cls_criterion(recn_dis_pred, real_label_recn)
        self.loss_G_adv = 0
        # print(self.loss_G_adv.item())

        if self.opt.use_style:
            self.loss_G_cls = self.cls_criterion(tran_cls_pred, self.SID3)
                              # self.cls_criterion(gene_cls_pred, self.SID3)
            self.loss_G_adv = self.loss_G_cls
            # print(self.loss_G_adv.item())

        self.loss_G_adv *= self.opt.lambda_adv
        # print(self.loss_G_adv.item())


        loss_logs = OrderedDict({})
        loss_logs["loss_G_adv"] = self.loss_G_adv.item()
        loss_logs["loss_G_dis"] = self.loss_G_dis.item()
        loss_logs["loss_G_cls"] = self.loss_G_cls.item()

        return loss_logs


    def backward_D(self):
        # real_samples = torch.cat([self.])
        _, real_dis_pred, _, real_cls_pred = self.discriminator(self.M2[:, :-4].detach())
        # if random.random() > 0.5:
        #     _, recn_dis_pred, _, recn_cls_pred = self.discriminator(torch.cat([self.RM1, self.RRM2], dim=0)[:, :-4])
        # else:
        #     _, recn_dis_pred, _, recn_cls_pred = self.discriminator(torch.cat([self.RRMN, self.RRM3], dim=0)[:, :-4])
        _, tran_dis_pred, _, _ = self.discriminator(self.RM3[:, :-4].detach())
        _, gene_dis_pred, _, _ = self.discriminator(self.NRM3[:, :-4].detach())

        real_label = self.zeros_like(real_dis_pred[:, 0], 0).long()
        tran_label = self.zeros_like(tran_dis_pred[:, 0], 1).long()
        gene_label = self.zeros_like(gene_dis_pred[:, 0], 2).long()
        # print(real_dis_pred.shape)
        # recn_label = self.zeros_like(recn_dis_pred[..., 0], 3)

        self.loss_D_dis = self.cls_criterion(real_dis_pred, real_label) + \
                          self.cls_criterion(tran_dis_pred, tran_label)
                          # self.cls_criterion(gene_dis_pred, gene_label)
                          # self.cls_criterion(recn_dis_pred, recn_label)

        self.loss_D_adv = 0

        if self.opt.use_style:
            self.loss_D_cls = self.cls_criterion(real_cls_pred, self.SID1)
            self.loss_D_adv = self.loss_D_cls

        self.loss_D_adv *= self.opt.lambda_adv

        loss_logs = OrderedDict({})
        loss_logs["loss_D_adv"] = self.loss_D_adv.item()
        loss_logs["loss_D_dis"] = self.loss_D_dis.item()
        loss_logs["loss_D_cls"] = self.loss_D_cls.item()

        return loss_logs

    def backward_R1(self):
        loss_dict = OrderedDict()
        self.M2.requires_grad_()
        _, real_dis_pred, _, real_cls_pred = self.discriminator(self.M2[:, :-4])
        pred_real = real_cls_pred.mean()
        grad_real, = torch.autograd.grad(
            outputs=pred_real,
            inputs=self.M2,
            create_graph=True,
            retain_graph=True,
            only_inputs=True
        )
        grad_real2 = grad_real.pow(2)
        # dims = list(range(1, grad_real2.ndim))
        grad_penalty = grad_real2.sum() / self.M2.size(0)
        self.grad_penalty = grad_penalty * self.opt.lambda_R1
        loss_dict["grad_penalty"] = self.grad_penalty.item()
        return loss_dict

    def update(self):
        """Run the four optimisation stages for the batch cached by ``forward``.

        Stages, in order:
          1. encoder + decoder on ``loss_EG_vae``;
          2. discriminator on ``loss_D_adv`` (after a fresh ``forward`` so the
             discriminator sees outputs of the just-updated encoder/decoder);
          3. discriminator on the gradient penalty from ``backward_R1``;
          4. decoder on ``loss_G_adv``.

        Returns:
            OrderedDict merging the logs of all four stages.
        """
        merged_logs = OrderedDict()
        vae_optimizers = [self.opt_encoder, self.opt_decoder]
        dis_optimizers = [self.opt_discriminator]
        gen_optimizers = [self.opt_decoder]

        # Stage 1: VAE objective.
        self.zero_grad(vae_optimizers)
        merged_logs.update(self.backward_EG_VAE())
        self.loss_EG_vae.backward()
        self.step(vae_optimizers)

        # Stage 2: discriminator objective on a refreshed forward pass.
        self.forward(self.batch_data)
        # Gradients are cleared through the module itself here.
        self.zero_grad([self.discriminator])
        merged_logs.update(self.backward_D())
        self.loss_D_adv.backward(retain_graph=True)
        self.step(dis_optimizers)

        # Stage 3: gradient penalty on the real batch.
        self.zero_grad([self.discriminator])
        merged_logs.update(self.backward_R1())
        self.grad_penalty.backward()
        self.step(dis_optimizers)

        # Stage 4: adversarial objective; only the decoder is stepped.
        self.zero_grad(gen_optimizers)
        merged_logs.update(self.backward_G_ADV())
        self.loss_G_adv.backward()
        self.step(gen_optimizers)
        return merged_logs
    
    def save(self, file_name, ep, total_it):
        state = {
            "encoder": self.encoder.state_dict(),
            "decoder": self.decoder.state_dict(),
            "discriminator": self.discriminator.state_dict(),

            "opt_encoder": self.opt_encoder.state_dict(),
            "opt_decoder": self.opt_decoder.state_dict(),
            "opt_discriminator": self.opt_discriminator.state_dict(),

            "ep": ep,
            "total_it": total_it,
        }

        torch.save(state, file_name)

    def resume(self, model_dir):
        # print(model_dir)
        checkpoint = torch.load(model_dir, map_location=self.opt.device)
        self.encoder.load_state_dict(checkpoint["encoder"])
        self.decoder.load_state_dict(checkpoint["decoder"])

        if self.opt.is_train:
            self.discriminator.load_state_dict(checkpoint["discriminator"])
            self.opt_discriminator.load_state_dict(checkpoint["opt_discriminator"])
            self.opt_encoder.load_state_dict(checkpoint["opt_encoder"])
            self.opt_decoder.load_state_dict(checkpoint["opt_decoder"])

        return checkpoint["ep"], checkpoint["total_it"]

    def train(self, train_dataloader, val_dataloader, plot_eval):
        """Train until ``opt.max_epoch``, validating and checkpointing every epoch.

        Args:
            train_dataloader: iterable of training batches accepted by
                ``forward``; must support ``len()``.
            val_dataloader: iterable of validation batches; must support
                ``len()``.  When it yields no batch, the validation summary
                and the best-model check are skipped for that epoch instead
                of crashing.
            plot_eval: callable invoked as ``plot_eval(data, save_dir, styles)``
                every ``opt.eval_every_e`` epochs with samples from the most
                recent forward pass.

        Side effects: moves the networks to ``opt.device``, creates the three
        Adam optimizers, writes ``latest.tar`` / ``best.tar`` under
        ``opt.model_dir`` and logs scalars through ``self.logger``.
        """
        net_list = [self.encoder, self.decoder, self.discriminator]
        self.to(net_list, self.opt.device)

        self.opt_encoder = optim.Adam(self.encoder.parameters(), lr=self.opt.lr)
        self.opt_decoder = optim.Adam(self.decoder.parameters(), lr=self.opt.lr)
        self.opt_discriminator = optim.Adam(self.discriminator.parameters(), lr=self.opt.lr)

        epoch = 0
        it = 0
        if self.opt.is_continue:
            model_dir = pjoin(self.opt.model_dir, "latest.tar")
            epoch, it = self.resume(model_dir)
            print("Loading model from Epoch %d"%(epoch))

        start_time = time.time()
        total_iters = self.opt.max_epoch * len(train_dataloader)
        print("Iters Per Epoch, Training: %04d, Validation: %03d"%(len(train_dataloader), len(val_dataloader)))
        min_val_loss = np.inf
        # Running sums of the training losses since the last log flush.
        logs = OrderedDict()

        while epoch < self.opt.max_epoch:
            self.net_train(net_list)
            for i, batch_data in enumerate(train_dataloader):
                self.forward(batch_data)
                loss_dict = self.update()

                for k, v in loss_dict.items():
                    if k not in logs:
                        logs[k] = v
                    else:
                        logs[k] += v
                it += 1
                if it % self.opt.log_every == 0:
                    mean_loss = OrderedDict()
                    for tag, value in logs.items():
                        mean_loss[tag] = value / self.opt.log_every
                        self.logger.scalar_summary(tag, mean_loss[tag], it)
                    logs = OrderedDict()
                    print_current_loss(start_time, it, total_iters, mean_loss, epoch, i)

                if it % self.opt.save_latest == 0:
                    self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)

            # NOTE(review): the end-of-epoch checkpoint stores the epoch
            # counter before it is incremented, whereas best.tar below stores
            # it after - confirm which value resuming is expected to see.
            self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)
            epoch += 1

            print("Validation time:")
            val_loss = None
            with torch.no_grad():
                self.net_eval(net_list)
                for batch_data in val_dataloader:
                    self.forward(batch_data)
                    loss_dict = self.backward_EG_VAE()
                    if val_loss is None:
                        val_loss = loss_dict
                    else:
                        for k, v in loss_dict.items():
                            val_loss[k] += v

            if val_loss is None:
                # Nothing was validated: there is no loss to report or to
                # compare against the best model so far.
                print("Validation Loss: skipped, the validation dataloader yielded no batches")
            else:
                print_str = "Validation Loss:"
                for k in val_loss:
                    val_loss[k] /= len(val_dataloader)
                    print_str += ' %s: %.4f '%(k, val_loss[k])
                self.logger.scalar_summary("val_loss", val_loss["loss_EG_vae"], epoch)
                print(print_str)

                if val_loss["loss_EG_vae"] < min_val_loss:
                    min_val_loss = val_loss["loss_EG_vae"]
                    self.save(pjoin(self.opt.model_dir, "best.tar"), epoch, it)
                    print("Best Validation Model So Far!~")

            if epoch % self.opt.eval_every_e == 0:
                # Every second sample among the first six of the latest
                # forward pass: inputs, reconstructions, swapped inputs and
                # transferred outputs, stacked along the batch axis.
                data = torch.cat([self.M2[:6:2], self.RM2[:6:2], self.M3[:6:2], self.RM3[:6:2]],
                                 dim=0)
                styles = torch.cat([self.SID1[:6:2], self.SID1[:6:2], self.SID3[:6:2], self.SID3[:6:2]], dim=0).cpu().detach().numpy()
                data = data.permute(0, 2, 1).detach().cpu().numpy()
                save_dir = pjoin(self.opt.eval_dir, "E%04d"%(epoch))
                os.makedirs(save_dir, exist_ok=True)
                plot_eval(data, save_dir, styles)


class VAEGANV2Trainer(BaseTrainer):
    def __init__(self, opt, encoder, decoder, discriminator):
        self.opt = opt
        self.encoder = encoder
        self.decoder = decoder
        self.discriminator = discriminator

        if self.opt.is_train:
            self.logger = Logger(self.opt.log_dir)
            self.mse_criterion = torch.nn.MSELoss()
            self.l1_criterion = torch.nn.SmoothL1Loss()
            self.cls_criterion_ls = torch.nn.CrossEntropyLoss(label_smoothing=0.1)
            self.cls_criterion = torch.nn.CrossEntropyLoss()

    @staticmethod
    def reparametrize(mu, logvar):
        s_var = logvar.mul(0.5).exp_()
        eps = s_var.data.new(s_var.size()).normal_()
        return eps.mul(s_var).add_(mu)

    @staticmethod
    def ones_like(tensor, val=1.):
        return torch.FloatTensor(tensor.size()).fill_(val).to(tensor.device).requires_grad_(False)

    @staticmethod
    def zeros_like(tensor, val=0.):
        return torch.FloatTensor(tensor.size()).fill_(val).to(tensor.device).requires_grad_(False)

    @staticmethod
    def kl_criterion(mu1, logvar1, mu2, logvar2):
        # KL( N(mu1, sigma2_1) || N(mu_2, sigma2_2))
        # loss = log(sigma2/sigma1) + (sigma1^2 + (mu1 - mu2)^2)/(2*sigma2^2) - 1/2
        sigma1 = logvar1.mul(0.5).exp()
        sigma2 = logvar2.mul(0.5).exp()
        kld = torch.log(sigma2 / sigma1) + (torch.exp(logvar1) + (mu1 - mu2) ** 2) / (
                2 * torch.exp(logvar2)) - 1 / 2
        return kld.sum() / np.prod(mu1.shape)

    @staticmethod
    def kl_criterion_unit(mu, logvar):
        # KL( N(mu1, sigma2_1) || N(mu_2, sigma2_2))
        # loss = log(sigma2/sigma1) + (sigma1^2 + (mu1 - mu2)^2)/(2*sigma2^2) - 1/2
        kld = ((torch.exp(logvar) + mu ** 2) - logvar - 1) / 2
        return kld.sum() / np.prod(mu.shape)

    def forward(self, batch_data):
        # M1, _, M2, _, A1, S1, SID1, _, _ = batch_data
        M1, M2, MS, _, A1, S1, SID1, _, _ = batch_data

        A2, S2 = A1, S1
        if self.opt.use_style:
            M2 = MS.clone()
        # print("M1", M1)
        # print("M2", M2)
        # (B, seq_len, pose_dim) -> (B, pose_dim, seq_len)
        M1[..., 1:3] *= 0
        M2[..., 1:3] *= 0
        MS[..., 1:3] *= 0

        M1 = M1.permute(0, 2, 1).to(self.opt.device).float().detach()
        M2 = M2.permute(0, 2, 1).to(self.opt.device).float().detach()
        MS = MS.permute(0, 2, 1).to(self.opt.device).float().detach()
        M3 = self.swap(M2)
        SID3 = self.swap(SID1)

        if self.opt.use_action:
            A1 = A1.to(self.opt.device).float().detach()
            A2 = A2.to(self.opt.device).float().detach()
            A3 = self.swap(A2)
        else:
            A1, A2, A3 = None, None, None

        if self.opt.use_style:
            S1 = S1.to(self.opt.device).float().detach()
            S2 = S2.to(self.opt.device).float().detach()
            S3 = self.swap(S2)
        else:
            S1, S2, S3 = None, None, None

        sp1, gl_mu1, gl_logvar1 = self.encoder(M1[:, :-4], A1, S1)
        sp2, gl_mu2, gl_logvar2 = self.encoder(M2[:, :-4], A2, S2)
        sp3, gl_mu3, gl_logvar3 = self.swap(sp2), self.swap(gl_mu2), self.swap(gl_logvar2)

        if self.opt.use_vae:
            # z_sp1 = self.reparametrize(sp_mu1, sp_logvar1)
            z_sp1 = sp1
            z_gl1 = self.reparametrize(gl_mu1, gl_logvar1)

            # z_sp2 = self.reparametrize(sp_mu2, sp_logvar2)
            z_sp2 = sp2
            # May detach the graph of M1
            z_gl2 = self.reparametrize(gl_mu2, gl_logvar2)

            # May detach the graph of M2
            # z_sp3 = self.reparametrize(sp_mu2.detach(), sp_logvar2.detach())
            z_sp3 = z_sp2.detach()
            z_gl3 = self.reparametrize(gl_mu3, gl_logvar3)
        else:
            z_sp1 = sp1
            z_gl1 = gl_mu1

            z_sp2 = sp2
            z_gl2 = gl_mu1

            z_sp3 = sp2
            z_gl3 = gl_mu3

        RM1 = self.decoder(z_sp1, z_gl1, A1, S1)
        RM2 = self.decoder(z_sp2, z_gl2, A2, S2)
        RM3 = self.decoder(z_sp3, z_gl3, A2, S3)

        z = torch.randn_like(gl_mu1, device=gl_mu1.device)
        NRM3 = self.decoder(z_sp3, z, A2, S3)

        # Should be identical to M2
        # May detach from graph of RM3
        sp4, gl_mu4, gl_logvar4 = self.encoder(RM3[:, :-4], A2, S3)

        spN, gl_muN, gl_logvarN = self.encoder(NRM3[:, :-4], A2, S3)

        if self.opt.use_vae:
            # z_sp4 = self.reparametrize(sp_mu4, sp_logvar4)
            z_sp4 = sp4
            # May detach from graph of M2
            z_gl4 = self.reparametrize(gl_mu2.detach(), gl_logvar2.detach())
            # print("gl_mu2", gl_mu2)
            # print("gl_logvar2", gl_logvar2)
            #  Should be identical to M3
            # May detach from graph of M3
            # z_sp5 = self.reparametrize(sp_mu3.detach(), sp_logvar3.detach())
            z_sp5 = sp3.detach()
            # print("gl_mu4", gl_mu4)
            # print("gl_logvar4", gl_logvar4)
            z_gl5 = self.reparametrize(gl_mu4, gl_logvar4)

            z_glN = self.reparametrize(gl_mu2.detach(), gl_logvar2.detach())
            z_spN = spN
            # print("z_gl5", z_gl5)
        else:
            z_sp4 = sp4
            z_gl4 = gl_mu2.detach()

            z_sp5 = sp3.detach()
            z_gl5 = gl_mu4

            z_spN = spN
            z_glN = gl_mu2.detach()

        RRM2 = self.decoder(z_sp4, z_gl4, A2, S2)
        RRM3 = self.decoder(z_sp5, z_gl5, A3, S3)
        RRMN = self.decoder(z_spN, z_glN, A2, S2)

        self.batch_data = batch_data
        self.RM3, self.NRM3 = RM3, NRM3
        self.SID1 = SID1.to(self.opt.device)
        self.SID3 = SID3.to(self.opt.device)
        self.S3, self.S2 = S3, S2
        self.M1, self.M2, self.M3, self.MS = M1, M2, M3, MS
        self.RM1, self.RM2, self.RRM2, self.RRM3, self.RRMN = RM1, RM2, RRM2, RRM3, RRMN
        self.gl_mu1, self.gl_mu2, self.gl_mu3, self.gl_mu4, self.gl_muN = gl_mu1, gl_mu2, gl_mu3, gl_mu4, gl_muN
        self.gl_logvar1, self.gl_logvar2, self.gl_logvar3, self.gl_logvar4, self.gl_logvarN = gl_logvar1, gl_logvar2, gl_logvar3, gl_logvar4, gl_logvarN

    def generate(self, M1, M2, A1, A2, S1, S2, sampling, label_switch):
        # M1, _, A1, S1, SID1 = batch_data
        # print("M1", M1)
        # print("M2", M2)
        # (B, seq_len, pose_dim) -> (B, pose_dim, seq_len)
        # OM1 = M1.clone()
        # OM2 = self.swap(OM1)
        M1 = M1.clone()
        M2 = M2.clone()
        M1[..., 1:3] *= 0
        M2[..., 1:3] *= 0
        M1 = M1.permute(0, 2, 1).to(self.opt.device).float().detach()
        M2 = M2.permute(0, 2, 1).to(self.opt.device).float().detach()

        if self.opt.use_style:
            S1 = S1.to(self.opt.device).float().detach()
            S2 = S2.to(self.opt.device).float().detach()
        else:
            S1, S2 = None, None

        if self.opt.use_action:
            A1 = A1.to(self.opt.device).float().detach()
            A2 = A2.to(self.opt.device).float().detach()
        else:
            A1, A2 = None, None
        # print(M1[:, :-4].shape, M2[:, :-4].shape, S1.shape, S2.shape)
        sp1, gl_mu1, gl_logvar1 = self.encoder(M1[:, :-4], A1, S1)
        sp2, gl_mu2, gl_logvar2 = self.encoder(M2[:, :-4], A2, S2)

        # z_sp = self.reparametrize(sp_mu1, sp_logvar1)
        z_sp = sp1
        if sampling:
            # Sample from normal distribution, novel style generation
            z_gl = self.reparametrize(self.zeros_like(gl_mu1), self.zeros_like(gl_logvar1))
        else:
            # Sample from M2 distribution, motion style transfer
            z_gl = self.reparametrize(gl_mu2, gl_logvar2)

        if label_switch:
            S = S2
        else:
            S = S1
        TM = self.decoder(z_sp, z_gl, A1, S)

        return TM.permute(0, 2, 1)

    def backward_EG_VAE(self):
        self.loss_rec_m1 = self.l1_criterion(self.M1, self.RM1)
        self.loss_rec_m2 = self.l1_criterion(self.M2, self.RM2)
        self.loss_rec_rm2 = self.l1_criterion(self.M2, self.RRM2)
        self.loss_rec_rm3 = self.l1_criterion(self.M3, self.RRM3)
        self.loss_rec_rmN = self.l1_criterion(self.M2, self.RRMN)

        # print(self.loss_rec_m1, self.loss_rec_m2, self.loss_rec_rm2, self.loss_rec_rm3)

        if self.opt.use_vae:
            self.loss_kld_gl_m1 = self.kl_criterion_unit(self.gl_mu1, self.gl_logvar1)
            self.loss_kld_gl_m2 = self.kl_criterion_unit(self.gl_mu2, self.gl_logvar2)
            self.loss_kld_gl_m4 = self.kl_criterion_unit(self.gl_mu4, self.gl_logvar4)
            self.loss_kld_gl_mN = self.kl_criterion_unit(self.gl_muN, self.gl_logvarN)
            self.loss_kld_gl_m12 = self.kl_criterion(self.gl_mu1, self.gl_logvar1, self.gl_mu2, self.gl_logvar2)
        else:
            self.loss_kld_gl_m1 = torch.zeros(1, device=self.gl_mu1.device)
            self.loss_kld_gl_m2 = torch.zeros(1, device=self.gl_mu1.device)
            self.loss_kld_gl_m4 = torch.zeros(1, device=self.gl_mu1.device)
            self.loss_kld_gl_m12 = torch.zeros(1, device=self.gl_mu1.device)

        # if self.opt.is_simple:
        # self.loss_EG_vae = (self.loss_rec_m1 + self.loss_rec_m2) * self.opt.lambda_rec + \
        #                    (self.loss_rec_rm2 + self.loss_rec_rm3) * self.opt.lambda_rec_c + \
        #                    (self.loss_kld_gl_m1 + self.loss_kld_gl_m2 + self.loss_kld_gl_m4) * self.opt.lambda_kld_gl + \
        #                    self.loss_kld_gl_m12 * self.opt.lambda_kld_gl12
        self.loss_EG_vae = (self.loss_rec_m1 + self.loss_rec_m2) * self.opt.lambda_rec + \
                           (self.loss_kld_gl_m1 + self.loss_kld_gl_m2) * self.opt.lambda_kld_gl + \
                           self.loss_kld_gl_m12 * self.opt.lambda_kld_gl12

        loss_logs = OrderedDict({})
        loss_logs["loss_EG_vae"] = self.loss_EG_vae.item()
        loss_logs["loss_rec_m1"] = self.loss_rec_m1.item()
        loss_logs["loss_rec_m2"] = self.loss_rec_m2.item()
        loss_logs["loss_rec_rm2"] = self.loss_rec_rm2.item()
        loss_logs["loss_rec_rm3"] = self.loss_rec_rm3.item()
        loss_logs["loss_rec_rmN"] = self.loss_rec_rmN.item()

        loss_logs["loss_kld_gl_m1"] = self.loss_kld_gl_m1.item()
        loss_logs["loss_kld_gl_m2"] = self.loss_kld_gl_m2.item()
        loss_logs["loss_kld_gl_m4"] = self.loss_kld_gl_m4.item()
        loss_logs["loss_kld_gl_mN"] = self.loss_kld_gl_mN.item()

        loss_logs["loss_kld_gl_m12"] = self.loss_kld_gl_m12.item()

        return loss_logs

    def backward_G_ADV(self):
        """Compute the generator-side adversarial loss.

        The discriminator scores the transferred motion ``self.RM3`` (and,
        unless ``opt.is_simple``, also ``self.NRM3``). The generator loss is
        the cross-entropy of those scores against class 0, which is the label
        ``backward_D`` assigns to real motions. The result, scaled by
        ``opt.lambda_adv``, is stored in ``self.loss_G_adv``; no backward pass
        is run here.

        In ``"paired"`` mode this reads ``self.m3_pair_feat``, which is set by
        ``backward_D``, so ``backward_D`` must have run beforehand.

        Returns:
            OrderedDict: ``G_acc_tran``, ``G_acc_gene`` (non-simple mode only)
            and ``loss_G_adv`` as Python floats.

        Raises:
            ValueError: If ``opt.adv_mode`` is neither ``"patch"`` nor
                ``"paired"``, or if ``"patch"`` is used without
                ``opt.use_style``.
        """
        loss_logs = OrderedDict({})

        # The last four entries along dim 1 are dropped before scoring.
        if self.opt.adv_mode == "patch":
            if not self.opt.use_style:
                # Raising a bare string is itself a TypeError in Python 3.
                raise ValueError("Patch Discriminator is Incompatible with Unsupervision")
            tran_dis_pred = self.discriminator(self.RM3[:, :-4], self.S3)
            if not self.opt.is_simple:
                gene_dis_pred = self.discriminator(self.NRM3[:, :-4], self.S3)
        elif self.opt.adv_mode == "paired":
            # The reference feature is detached so only the generated branch
            # receives gradient from this loss.
            tran_dis_pred = self.discriminator.discriminate_feature(
                self.discriminator.extract_feature(self.RM3[:, :-4]), self.m3_pair_feat.detach()
            )
            if not self.opt.is_simple:
                gene_dis_pred = self.discriminator.discriminate_feature(
                    self.discriminator.extract_feature(self.NRM3[:, :-4]), self.m3_pair_feat.detach()
                )
        else:
            raise ValueError("Unrecognized adverarial learning mode")

        # Target class 0 for every prediction (the "real" class in backward_D).
        real_label = self.zeros_like(tran_dis_pred[:, 0], 0).long()
        pred_tran_label = tran_dis_pred.argmax(dim=1)
        acc_tran = (pred_tran_label == real_label).sum() / np.prod(pred_tran_label.shape)
        loss_logs["G_acc_tran"] = acc_tran.item()

        if self.opt.is_simple:
            self.loss_G_adv = self.cls_criterion(tran_dis_pred, real_label)
        else:
            self.loss_G_adv = self.cls_criterion(tran_dis_pred, real_label) + \
                              self.cls_criterion(gene_dis_pred, real_label)
            pred_gene_label = gene_dis_pred.argmax(dim=1)
            acc_gene = (pred_gene_label == real_label).sum() / np.prod(pred_gene_label.shape)
            loss_logs["G_acc_gene"] = acc_gene.item()

        self.loss_G_adv = self.loss_G_adv * self.opt.lambda_adv
        loss_logs["loss_G_adv"] = self.loss_G_adv.item()

        return loss_logs

    def backward_D(self):
        """Compute the discriminator classification loss.

        The discriminator is trained to label real motions (``self.M2``) as
        class 0, transferred motions (``self.RM3``) as class 1 and, unless
        ``opt.is_simple``, ``self.NRM3`` as class 2. All inputs are detached,
        so this loss does not reach the encoder or decoder. The loss is stored
        in ``self.loss_D_adv``; no backward pass is run here.

        In ``"paired"`` mode the feature extracted from ``self.M3`` is cached
        in ``self.m3_pair_feat`` for later use by ``backward_G_ADV``.

        Returns:
            OrderedDict: ``D_acc_gene`` (non-simple mode only), ``D_acc_real``,
            ``D_acc_tran`` and ``loss_D_adv`` as Python floats.

        Raises:
            ValueError: If ``opt.adv_mode`` is neither ``"patch"`` nor
                ``"paired"``, or if ``"patch"`` is used without
                ``opt.use_style``.
        """
        loss_logs = OrderedDict({})

        if self.opt.adv_mode == "patch":
            if not self.opt.use_style:
                # Raising a bare string is itself a TypeError in Python 3.
                raise ValueError("Patch Discriminator is Incompatible with Unsupervision")

            real_dis_pred = self.discriminator(self.M2[:, :-4].detach(), self.S2)
            tran_dis_pred = self.discriminator(self.RM3[:, :-4].detach(), self.S3)
            if not self.opt.is_simple:
                gene_dis_pred = self.discriminator(self.NRM3[:, :-4].detach(), self.S3)
        elif self.opt.adv_mode == "paired":
            # The partner of M2 in the real pair is MS with style labels,
            # otherwise M1.
            if self.opt.use_style:
                m2_pair_feat = self.discriminator.extract_feature(self.MS[:, :-4].detach())
            else:
                m2_pair_feat = self.discriminator.extract_feature(self.M1[:, :-4].detach())
            m3_pair_feat = self.discriminator.extract_feature(self.M3[:, :-4].detach())
            real_dis_pred = self.discriminator.discriminate_feature(
                self.discriminator.extract_feature(self.M2[:, :-4].detach()),
                m2_pair_feat
            )
            tran_dis_pred = self.discriminator.discriminate_feature(
                self.discriminator.extract_feature(self.RM3[:, :-4].detach()),
                m3_pair_feat
            )
            # Cached for backward_G_ADV.
            self.m3_pair_feat = m3_pair_feat
            if not self.opt.is_simple:
                gene_dis_pred = self.discriminator.discriminate_feature(
                    self.discriminator.extract_feature(self.NRM3[:, :-4].detach()),
                    m3_pair_feat
                )
        else:
            raise ValueError("Unrecognized Adversarial Learning Mode")

        # Class indices: 0 = real, 1 = transferred, 2 = generated.
        real_label = self.zeros_like(real_dis_pred[:, 0], 0).long()
        tran_label = self.zeros_like(tran_dis_pred[:, 0], 1).long()
        if self.opt.is_simple:
            self.loss_D_adv = self.cls_criterion_ls(real_dis_pred, real_label) + \
                              self.cls_criterion_ls(tran_dis_pred, tran_label)
        else:
            gene_label = self.zeros_like(gene_dis_pred[:, 0], 2).long()
            self.loss_D_adv = self.cls_criterion_ls(real_dis_pred, real_label) + \
                              self.cls_criterion_ls(tran_dis_pred, tran_label) + \
                              self.cls_criterion_ls(gene_dis_pred, gene_label)
            pred_gene_label = gene_dis_pred.argmax(dim=1)
            acc_gene = (pred_gene_label == gene_label).sum() / np.prod(pred_gene_label.shape)
            loss_logs["D_acc_gene"] = acc_gene.item()

        pred_real_label = real_dis_pred.argmax(dim=1)
        pred_tran_label = tran_dis_pred.argmax(dim=1)

        acc_real = (pred_real_label == real_label).sum() / np.prod(pred_real_label.shape)
        acc_tran = (pred_tran_label == tran_label).sum() / np.prod(pred_tran_label.shape)

        loss_logs["D_acc_real"] = acc_real.item()
        loss_logs["D_acc_tran"] = acc_tran.item()
        loss_logs["loss_D_adv"] = self.loss_D_adv.item()
        return loss_logs

    def backward_R1(self):
        """Compute a gradient penalty on the discriminator's real inputs.

        The mean discriminator output on real data is differentiated with
        respect to its two inputs; the summed squared gradients, divided by
        the batch size ``self.M2.size(0)`` and scaled by ``opt.lambda_R1``,
        are stored in ``self.grad_penalty``. The gradient is taken with
        ``create_graph=True`` so the penalty can itself be back-propagated.

        Note that ``requires_grad_()`` is switched on in place for the stored
        input tensors (``self.M2`` and ``self.S2`` / ``self.MS`` /
        ``self.M1``, depending on the mode).

        Returns:
            OrderedDict: ``grad_penalty`` as a Python float.

        Raises:
            ValueError: If ``opt.adv_mode`` is neither ``"patch"`` nor
                ``"paired"``.
        """
        loss_dict = OrderedDict()
        if self.opt.adv_mode == "patch":
            O1, O2 = self.M2, self.S2
            O1.requires_grad_()
            O2.requires_grad_()
            pred_real = self.discriminator(O1[:, :-4], O2)
        elif self.opt.adv_mode == "paired":
            O1 = self.M2
            # Same real pair that backward_D scores: (M2, MS) with style
            # labels, (M2, M1) without.
            if self.opt.use_style:
                O2 = self.MS
            else:
                O2 = self.M1
            O1.requires_grad_()
            O2.requires_grad_()
            pred_real = self.discriminator.discriminate_feature(
                self.discriminator.extract_feature(O1[:, :-4]),
                self.discriminator.extract_feature(O2[:, :-4])
            )
        else:
            # Raising a bare string is itself a TypeError in Python 3.
            raise ValueError("Unrecognized Adversarial Learning Mode")

        grad_o1, grad_o2 = torch.autograd.grad(
            outputs=pred_real.mean(),
            inputs=[O1, O2],
            create_graph=True,
            retain_graph=True,
            only_inputs=True
        )
        grad_penalty = (grad_o1.pow(2).sum() + grad_o2.pow(2).sum()) / self.M2.size(0)
        self.grad_penalty = grad_penalty * self.opt.lambda_R1
        loss_dict["grad_penalty"] = self.grad_penalty.item()
        return loss_dict

    def update(self):
        """Run one optimisation step on the current batch.

        The encoder and decoder are always updated on the VAE loss. Once
        ``self.epoch >= opt.adv_start_ep`` the batch is forwarded again and
        three more steps follow: the discriminator on ``loss_D_adv``, the
        discriminator on the gradient penalty, and the decoder on
        ``loss_G_adv``.

        Returns:
            OrderedDict: The merged log entries of every loss computed.

        Raises:
            ValueError: If ``opt.use_style`` is off and ``opt.is_simple`` is
                not set.
        """
        loss_dict = OrderedDict()
        if not self.opt.use_style:
            if not self.opt.is_simple:
                # Raising a bare string is itself a TypeError in Python 3.
                raise ValueError("Must Use Simple Mode for Unsupervised Learning")
        self.zero_grad([self.opt_encoder, self.opt_decoder])
        loss_logs = self.backward_EG_VAE()
        loss_dict.update(loss_logs)
        self.loss_EG_vae.backward()
        self.step([self.opt_encoder, self.opt_decoder])

        if self.epoch >= self.opt.adv_start_ep:
            # Recompute the forward pass with the freshly updated weights.
            self.forward(self.batch_data)
            # zero_grad() is called on the discriminator network itself here,
            # not on its optimizer; kept as in the original.
            self.zero_grad([self.discriminator])
            loss_logs = self.backward_D()
            loss_dict.update(loss_logs)
            self.loss_D_adv.backward()
            self.step([self.opt_discriminator])

            self.zero_grad([self.discriminator])
            loss_logs = self.backward_R1()
            loss_dict.update(loss_logs)
            self.grad_penalty.backward()
            self.step([self.opt_discriminator])

            # Only the decoder is stepped on the adversarial loss.
            self.zero_grad([self.opt_decoder])
            loss_logs = self.backward_G_ADV()
            loss_dict.update(loss_logs)
            self.loss_G_adv.backward()
            self.step([self.opt_decoder])
        return loss_dict

    def save(self, file_name, ep, total_it):
        """Write a training checkpoint to ``file_name`` with ``torch.save``.

        The checkpoint holds the state dicts of the three networks and their
        optimizers, plus the epoch ``ep`` and the iteration count ``total_it``.
        """
        components = (
            ("encoder", self.encoder),
            ("decoder", self.decoder),
            ("discriminator", self.discriminator),
            ("opt_encoder", self.opt_encoder),
            ("opt_decoder", self.opt_decoder),
            ("opt_discriminator", self.opt_discriminator),
        )
        checkpoint = {key: component.state_dict() for key, component in components}
        # Progress counters, returned again by resume().
        checkpoint["ep"] = ep
        checkpoint["total_it"] = total_it

        torch.save(checkpoint, file_name)

    def resume(self, model_dir):
        """Restore state from the checkpoint at ``model_dir``.

        The encoder and decoder weights are always loaded. The discriminator
        and the three optimizers are loaded only when ``opt.is_train`` is set.

        Returns:
            tuple: ``(ep, total_it)`` as stored in the checkpoint.
        """
        ckpt = torch.load(model_dir, map_location=self.opt.device)

        for key in ("encoder", "decoder"):
            getattr(self, key).load_state_dict(ckpt[key])

        if self.opt.is_train:
            training_only = ("discriminator", "opt_discriminator", "opt_encoder", "opt_decoder")
            for key in training_only:
                getattr(self, key).load_state_dict(ckpt[key])

        return ckpt["ep"], ckpt["total_it"]

    def train(self, train_dataloader, val_dataloader, plot_eval):
        """Run the full training loop with per-epoch validation.

        Creates the three Adam optimizers (the discriminator uses four times
        ``opt.lr``), optionally resumes from ``latest.tar`` and then, for each
        epoch up to ``opt.max_epoch``:

        * trains on every batch, logging averaged losses every
          ``opt.log_every`` iterations and saving ``latest.tar`` every
          ``opt.save_latest`` iterations and at the end of the epoch;
        * evaluates ``backward_EG_VAE`` on the validation set and saves
          ``best.tar`` whenever the mean ``loss_EG_vae`` improves;
        * every ``opt.eval_every_e`` epochs passes a few motions from the
          most recent forward pass to ``plot_eval``.

        Args:
            train_dataloader: Sized iterable of training batches.
            val_dataloader: Sized iterable of validation batches. If it yields
                no batches, the validation summary and best-model check are
                skipped for that epoch.
            plot_eval: Callable invoked as ``plot_eval(data, save_dir, styles)``
                with NumPy arrays.
        """
        net_list = [self.encoder, self.decoder, self.discriminator]
        self.to(net_list, self.opt.device)

        self.opt_encoder = optim.Adam(self.encoder.parameters(), lr=self.opt.lr)
        self.opt_decoder = optim.Adam(self.decoder.parameters(), lr=self.opt.lr)
        self.opt_discriminator = optim.Adam(self.discriminator.parameters(), lr=self.opt.lr*4)

        epoch = 0
        it = 0
        if self.opt.is_continue:
            model_dir = pjoin(self.opt.model_dir, "latest.tar")
            epoch, it = self.resume(model_dir)
            print("Loading model from Epoch %d" % (epoch))

        start_time = time.time()
        total_iters = self.opt.max_epoch * len(train_dataloader)
        print("Iters Per Epoch, Training: %04d, Validation: %03d" % (len(train_dataloader), len(val_dataloader)))
        min_val_loss = np.inf
        # Running sums of the logged losses since the last log flush.
        logs = OrderedDict()

        while epoch < self.opt.max_epoch:
            self.epoch = epoch
            self.net_train(net_list)
            for i, batch_data in enumerate(train_dataloader):
                self.forward(batch_data)
                loss_dict = self.update()

                for k, v in loss_dict.items():
                    if k not in logs:
                        logs[k] = v
                    else:
                        logs[k] += v
                it += 1
                if it % self.opt.log_every == 0:
                    mean_loss = OrderedDict()

                    for tag, value in logs.items():
                        self.logger.scalar_summary(tag, value / self.opt.log_every, it)
                        mean_loss[tag] = value / self.opt.log_every
                    logs = OrderedDict()
                    print_current_loss(start_time, it, total_iters, mean_loss, epoch, i)

                if it % self.opt.save_latest == 0:
                    self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)

            self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)
            epoch += 1

            print("Validation time:")
            val_loss = None
            with torch.no_grad():
                self.net_eval(net_list)
                for i, batch_data in enumerate(val_dataloader):
                    self.forward(batch_data)
                    loss_dict = self.backward_EG_VAE()
                    if val_loss is None:
                        val_loss = loss_dict
                    else:
                        for k, v in loss_dict.items():
                            val_loss[k] += v

            if val_loss is None:
                # No validation batches: nothing to average or compare, so
                # skip the summary instead of failing on None.
                print("Validation set is empty; skipping validation summary.")
            else:
                print_str = "Validation Loss:"
                for k in val_loss:
                    val_loss[k] /= len(val_dataloader)
                    print_str += ' %s: %.4f ' % (k, val_loss[k])
                self.logger.scalar_summary("val_loss", val_loss["loss_EG_vae"], epoch)
                print(print_str)

                if val_loss["loss_EG_vae"] < min_val_loss:
                    min_val_loss = val_loss["loss_EG_vae"]
                    self.save(pjoin(self.opt.model_dir, "best.tar"), epoch, it)
                    print("Best Validation Model So Far!~")

            if epoch % self.opt.eval_every_e == 0:
                # Uses the tensors left on self by the most recent forward().
                data = torch.cat([self.M2[:6:2], self.RM2[:6:2], self.M3[:6:2], self.RM3[:6:2]],
                                 dim=0)
                styles = torch.cat([self.SID1[:6:2], self.SID1[:6:2], self.SID3[:6:2], self.SID3[:6:2]],
                                   dim=0).cpu().detach().numpy()
                data = data.permute(0, 2, 1).detach().cpu().numpy()
                save_dir = pjoin(self.opt.eval_dir, "E%04d" % (epoch))
                os.makedirs(save_dir, exist_ok=True)
                plot_eval(data, save_dir, styles)


class VAELatentAdvTrainer(BaseTrainer):
    def __init__(self, opt, encoder, decoder, discriminator):
        self.opt = opt
        self.encoder = encoder
        self.decoder = decoder
        self.discriminator = discriminator

        if self.opt.is_train:
            self.logger = Logger(self.opt.log_dir)
            self.mse_criterion = torch.nn.MSELoss()
            self.l1_criterion = torch.nn.SmoothL1Loss()
            # self.cls_criterion_ls = torch.nn.CrossEntropyLoss(label_smoothing=0.1)
            self.cls_criterion = torch.nn.CrossEntropyLoss()

    @staticmethod
    def reparametrize(mu, logvar):
        s_var = logvar.mul(0.5).exp_()
        eps = s_var.data.new(s_var.size()).normal_()
        return eps.mul(s_var).add_(mu)

    @staticmethod
    def ones_like(tensor, val=1.):
        return torch.FloatTensor(tensor.size()).fill_(val).to(tensor.device).requires_grad_(False)

    @staticmethod
    def zeros_like(tensor, val=0.):
        return torch.FloatTensor(tensor.size()).fill_(val).to(tensor.device).requires_grad_(False)

    @staticmethod
    def kl_criterion(mu1, logvar1, mu2, logvar2):
        # KL( N(mu1, sigma2_1) || N(mu_2, sigma2_2))
        # loss = log(sigma2/sigma1) + (sigma1^2 + (mu1 - mu2)^2)/(2*sigma2^2) - 1/2
        sigma1 = logvar1.mul(0.5).exp()
        sigma2 = logvar2.mul(0.5).exp()
        kld = torch.log(sigma2 / sigma1) + (torch.exp(logvar1) + (mu1 - mu2) ** 2) / (
                2 * torch.exp(logvar2)) - 1 / 2
        return kld.sum() / np.prod(mu1.shape)

    @staticmethod
    def kl_criterion_unit(mu, logvar):
        # KL( N(mu1, sigma2_1) || N(mu_2, sigma2_2))
        # loss = log(sigma2/sigma1) + (sigma1^2 + (mu1 - mu2)^2)/(2*sigma2^2) - 1/2
        kld = ((torch.exp(logvar) + mu ** 2) - logvar - 1) / 2
        return kld.sum() / np.prod(mu.shape)

    def forward(self, batch_data):
        # M1, _, M2, _, A1, S1, SID1, _, _ = batch_data
        M1, M2, MS, MD, A1, S1, SID1, _, _ = batch_data

        A2, S2 = A1, S1
        if self.opt.use_style:
            M2 = MS.clone()
        # print("M1", M1)
        # print("M2", M2)
        # (B, seq_len, pose_dim) -> (B, pose_dim, seq_len)
        M1[..., 1:3] *= 0
        M2[..., 1:3] *= 0
        # MS[..., 1:3] *= 0
        MD[..., 1:3] *= 0

        M1 = M1.permute(0, 2, 1).to(self.opt.device).float().detach()
        M2 = M2.permute(0, 2, 1).to(self.opt.device).float().detach()
        MD = MD.permute(0, 2, 1).to(self.opt.device).float().detach()
        M3 = self.swap(M2)
        SID3 = self.swap(SID1)

        if self.opt.use_action:
            A1 = A1.to(self.opt.device).float().detach()
            A2 = A2.to(self.opt.device).float().detach()
            A3 = self.swap(A2)
        else:
            A1, A2, A3 = None, None, None

        if self.opt.use_style:
            S1 = S1.to(self.opt.device).float().detach()
            S2 = S2.to(self.opt.device).float().detach()
            S3 = self.swap(S2)
        else:
            S1, S2, S3 = None, None, None

        sp1, gl_mu1, gl_logvar1 = self.encoder(M1[:, :-4], A1, S1)
        sp2, gl_mu2, gl_logvar2 = self.encoder(M2[:, :-4], A2, S2)
        sp3, gl_mu3, gl_logvar3 = self.swap(sp2), self.swap(gl_mu2), self.swap(gl_logvar2)

        if self.opt.use_vae:
            # z_sp1 = self.reparametrize(sp_mu1, sp_logvar1)
            z_sp1 = sp1
            z_gl1 = self.reparametrize(gl_mu1, gl_logvar1)

            # z_sp2 = self.reparametrize(sp_mu2, sp_logvar2)
            z_sp2 = sp2
            # May detach the graph of M1
            z_gl2 = self.reparametrize(gl_mu2, gl_logvar2)

            # May detach the graph of M2
            # M2 content + M3 Style
            z_sp3 = z_sp2.detach()
            z_gl3 = self.reparametrize(gl_mu3, gl_logvar3)
        else:
            z_sp1 = sp1
            z_gl1 = gl_mu1

            z_sp2 = sp2
            z_gl2 = gl_mu1

            z_sp3 = sp2
            z_gl3 = gl_mu3

        # Reconstruction for M1 and M2
        RM1 = self.decoder(z_sp1, z_gl1, A1, S1)
        RM2 = self.decoder(z_sp2, z_gl2, A2, S2)
        # M2 content + M3 Style
        RM3 = self.decoder(z_sp3, z_gl3, A2, S3)

        ####################
        #####Cycle Loss#####
        ####################

        # Should be identical to M2
        # May detach from graph of RM3
        sp4, gl_mu4, gl_logvar4 = self.encoder(RM3[:, :-4], A2, S3)

        if self.opt.use_vae:
            # z_sp4 = self.reparametrize(sp_mu4, sp_logvar4)
            z_sp4 = sp4
            # May detach from graph of M2
            z_gl4 = self.reparametrize(gl_mu2.detach(), gl_logvar2.detach())


            #  Should be identical to M3
            # May detach from graph of M3
            # z_sp5 = self.reparametrize(sp_mu3.detach(), sp_logvar3.detach())
            z_sp5 = sp3.detach()
            # print("gl_mu4", gl_mu4)
            # print("gl_logvar4", gl_logvar4)
            z_gl5 = self.reparametrize(gl_mu4, gl_logvar4)
        else:
            z_sp4 = sp4
            z_gl4 = gl_mu2.detach()

            z_sp5 = sp3.detach()
            z_gl5 = gl_mu4

        RRM2 = self.decoder(z_sp4, z_gl4, A2, S2)
        RRM3 = self.decoder(z_sp5, z_gl5, A3, S3)

        self.batch_data = batch_data
        self.RM3 = RM3
        self.SID1 = SID1.to(self.opt.device)
        self.SID3 = SID3.to(self.opt.device)
        self.S3, self.S2 = S3, S2
        self.M1, self.M2, self.M3, self.MD = M1, M2, M3, MD
        self.RM1, self.RM2, self.RRM2, self.RRM3 = RM1, RM2, RRM2, RRM3
        self.gl_mu1, self.gl_mu2, self.gl_mu3, self.gl_mu4 = gl_mu1, gl_mu2, gl_mu3, gl_mu4
        self.gl_logvar1, self.gl_logvar2, self.gl_logvar3, self.gl_logvar4 = gl_logvar1, gl_logvar2, gl_logvar3, gl_logvar4
        self.sp2, self.sp4 = sp2, sp4

    def generate(self, M1, M2, A1, A2, S1, S2, sampling, label_switch):
        # M1, _, A1, S1, SID1 = batch_data
        # print("M1", M1)
        # print("M2", M2)
        # (B, seq_len, pose_dim) -> (B, pose_dim, seq_len)
        # OM1 = M1.clone()
        # OM2 = self.swap(OM1)
        M1 = M1.clone()
        M2 = M2.clone()
        M1[..., 1:3] *= 0
        M2[..., 1:3] *= 0
        M1 = M1.permute(0, 2, 1).to(self.opt.device).float().detach()
        M2 = M2.permute(0, 2, 1).to(self.opt.device).float().detach()

        if self.opt.use_style:
            S1 = S1.to(self.opt.device).float().detach()
            S2 = S2.to(self.opt.device).float().detach()
        else:
            S1, S2 = None, None

        if self.opt.use_action:
            A1 = A1.to(self.opt.device).float().detach()
            A2 = A2.to(self.opt.device).float().detach()
        else:
            A1, A2 = None, None
        # print(M1[:, :-4].shape, M2[:, :-4].shape, S1.shape, S2.shape)
        sp1, gl_mu1, gl_logvar1 = self.encoder(M1[:, :-4], A1, S1)
        sp2, gl_mu2, gl_logvar2 = self.encoder(M2[:, :-4], A2, S2)

        # z_sp = self.reparametrize(sp_mu1, sp_logvar1)
        z_sp = sp1
        if sampling:
            # Sample from normal distribution, novel style generation
            z_gl = self.reparametrize(self.zeros_like(gl_mu1), self.zeros_like(gl_logvar1))
        else:
            # Sample from M2 distribution, motion style transfer
            z_gl = self.reparametrize(gl_mu2, gl_logvar2)

        if label_switch:
            S = S2
        else:
            S = S1
        TM = self.decoder(z_sp, z_gl, A1, S)

        return TM.permute(0, 2, 1)

    def backward_EG(self, first_iter=False):
        self.loss_rec_m1 = self.l1_criterion(self.M1, self.RM1)
        self.loss_rec_m2 = self.l1_criterion(self.M2, self.RM2)
        self.loss_rec_rm2 = self.l1_criterion(self.M2, self.RRM2)
        self.loss_rec_rm3 = self.l1_criterion(self.M3, self.RRM3)
        self.loss_rec_lat = self.l1_criterion(self.sp2, self.sp4)
        self.loss_lat_reg = torch.mean(torch.abs(self.sp2))

        # print(self.loss_rec_m1, self.loss_rec_m2, self.loss_rec_rm2, self.loss_rec_rm3)

        if self.opt.use_vae:
            self.loss_kld_gl_m1 = self.kl_criterion_unit(self.gl_mu1, self.gl_logvar1)
            self.loss_kld_gl_m2 = self.kl_criterion_unit(self.gl_mu2, self.gl_logvar2)
            self.loss_kld_gl_m4 = self.kl_criterion_unit(self.gl_mu4, self.gl_logvar4)
            self.loss_kld_gl_m12 = self.kl_criterion(self.gl_mu1, self.gl_logvar1, self.gl_mu2, self.gl_logvar2)
            self.loss_kld_gl_m34 = self.kl_criterion(self.gl_mu3, self.gl_logvar3, self.gl_mu4, self.gl_logvar4)

        else:
            self.loss_kld_gl_m1 = torch.zeros(1, device=self.gl_mu1.device)
            self.loss_kld_gl_m2 = torch.zeros(1, device=self.gl_mu1.device)
            self.loss_kld_gl_m4 = torch.zeros(1, device=self.gl_mu1.device)
            self.loss_kld_gl_m12 = torch.zeros(1, device=self.gl_mu1.device)

        if not first_iter:
            if self.opt.adv_mode == "patch":
                if not self.opt.use_style:
                    raise "Patch Discriminator is Incompatible with Unsupervision"
                cls_pred = self.discriminator(self.sp2)
                real_label = self.SID1[:,np.newaxis].repeat(1, cls_pred.shape[-1])
                self.loss_G_adv = -self.cls_criterion(cls_pred, real_label)
            elif self.opt.adv_mode == "paired":
                if self.opt.use_style:
                    # m2_pair_feat = self.discriminator.extract_feature(self.MS[:, :-4].detach())
                    other_feat = self.discriminator.extract_feature(
                        self.encoder.extract_content_feature(self.MD[:, :-4]))
                else:
                    # m2_pair_feat = self.discriminator.extract_feature(self.M1[:, :-4].detach())
                    other_feat = self.discriminator.extract_feature(self.sp3)
                m1_feat = self.discriminator.extract_feature(self.sp1)
                m2_feat = self.discriminator.extract_feature(self.sp2)
                pos_pair_pred = self.discriminator.discriminate_feature(m2_feat, m1_feat)
                neg_pair_pred = self.discriminator.discriminate_feature(other_feat, m2_feat)
                pos_label = self.ones_like(pos_pair_pred[:, 0]).long()
                neg_label = self.zeros_like(neg_pair_pred[:, 0]).long()
                if self.opt.adv_saturate:
                    self.loss_G_adv = (self.cls_criterion(pos_pair_pred, neg_label) +
                                       self.cls_criterion(neg_pair_pred, pos_label))
                else:
                    self.loss_G_adv = -(self.cls_criterion(pos_pair_pred, pos_label) +
                                        self.cls_criterion(neg_pair_pred, neg_label))
            self.loss_EG = (self.loss_rec_m1 + self.loss_rec_m2) * self.opt.lambda_rec + \
                               (self.loss_rec_rm2 + self.loss_rec_rm3) * self.opt.lambda_rec_c + \
                               (self.loss_kld_gl_m1 + self.loss_kld_gl_m2 + self.loss_kld_gl_m4) * self.opt.lambda_kld_gl + \
                               self.loss_kld_gl_m12 * self.opt.lambda_kld_gl12 + self.loss_G_adv*self.opt.lambda_adv + \
                           self.loss_lat_reg * 0.001
            # self.loss_EG = (self.loss_rec_m1 + self.loss_rec_m2) * self.opt.lambda_rec + \
            #                (self.loss_kld_gl_m1 + self.loss_kld_gl_m2) * self.opt.lambda_kld_gl + \
            #                self.loss_kld_gl_m12 * self.opt.lambda_kld_gl12 + self.loss_G_adv * self.opt.lambda_adv
        else:
            self.loss_G_adv = torch.zeros(1, device=self.gl_mu1.device)
            self.loss_EG = (self.loss_rec_m1 + self.loss_rec_m2) * self.opt.lambda_rec + \
                               (self.loss_rec_rm2 + self.loss_rec_rm3) * self.opt.lambda_rec_c + \
                               (self.loss_kld_gl_m1 + self.loss_kld_gl_m2 + self.loss_kld_gl_m4) * self.opt.lambda_kld_gl + \
                               self.loss_kld_gl_m12 * self.opt.lambda_kld_gl12 +\
                           self.loss_lat_reg * 0.001
            # self.loss_EG = (self.loss_rec_m1 + self.loss_rec_m2) * self.opt.lambda_rec + \
            #                (self.loss_kld_gl_m1 + self.loss_kld_gl_m2) * self.opt.lambda_kld_gl + \
            #                self.loss_kld_gl_m12 * self.opt.lambda_kld_gl12


        loss_logs = OrderedDict({})
        loss_logs["loss_EG_vae"] = self.loss_EG.item()
        loss_logs["loss_G_adv"] = self.loss_G_adv.item()
        loss_logs["loss_rec_m1"] = self.loss_rec_m1.item()
        loss_logs["loss_rec_m2"] = self.loss_rec_m2.item()
        loss_logs["loss_rec_rm2"] = self.loss_rec_rm2.item()
        loss_logs["loss_rec_rm3"] = self.loss_rec_rm3.item()
        loss_logs["loss_rec_lat"] = self.loss_rec_lat.item()
        loss_logs["loss_lat_reg"] = self.loss_lat_reg.item()

        loss_logs["loss_kld_gl_m1"] = self.loss_kld_gl_m1.item()
        loss_logs["loss_kld_gl_m2"] = self.loss_kld_gl_m2.item()
        loss_logs["loss_kld_gl_m4"] = self.loss_kld_gl_m4.item()

        loss_logs["loss_kld_gl_m12"] = self.loss_kld_gl_m12.item()
        loss_logs["loss_kld_gl_m34"] = self.loss_kld_gl_m34.item()


        return loss_logs


    def backward_D(self):
        loss_logs = OrderedDict({})

        if self.opt.adv_mode == "patch":
            if not self.opt.use_style:
                raise "Patch Discriminator is Incompatible with Unsupervision"
            cls_pred = self.discriminator(self.sp2.detach())
            real_label = self.SID1[:,np.newaxis].repeat(1, cls_pred.shape[-1])
            self.loss_D_adv = self.cls_criterion(cls_pred, real_label)
            pred_cls_label = cls_pred.argmax(dim=1)
            acc = (pred_cls_label == real_label).sum() / np.prod(pred_cls_label.shape)
            loss_logs["D_acc"] = acc.item()
        elif self.opt.adv_mode == "paired":
            if self.opt.use_style:
                # m2_pair_feat = self.discriminator.extract_feature(self.MS[:, :-4].detach())
                other_feat = self.discriminator.extract_feature(
                    self.encoder.extract_content_feature(self.MD[:, :-4]).detach())
            else:
                # m2_pair_feat = self.discriminator.extract_feature(self.M1[:, :-4].detach())
                other_feat = self.discriminator.extract_feature(self.sp3.detach())
            # other = self.discriminator.extract_feature(self.M3[:, :-4].detach())
            m1_feat = self.discriminator.extract_feature(self.sp1.detach())
            m2_feat = self.discriminator.extract_feature(self.sp2.detach())
            # print(m2_feat.shape, other_feat.shape)
            neg_pair_pred = self.discriminator.discriminate_feature(m2_feat, other_feat)
            pos_pair_pred = self.discriminator.discriminate_feature(m2_feat, m1_feat)
            pos_label = self.ones_like(pos_pair_pred[:, 0]).long()
            neg_label = self.zeros_like(neg_pair_pred[:, 0]).long()
            self.loss_D_adv = self.cls_criterion(neg_pair_pred, neg_label) + \
                              self.cls_criterion(pos_pair_pred, pos_label)
            pred_neg_label = neg_pair_pred.argmax(dim=1)
            acc_neg = (pred_neg_label == neg_label).sum() / np.prod(pred_neg_label.shape)
            loss_logs["D_acc_neg"] = acc_neg.item()
            pred_pos_label = pos_pair_pred.argmax(dim=1)
            acc_pos = (pred_pos_label == pos_label).sum() / np.prod(pred_pos_label.shape)
            loss_logs["D_acc_pos"] = acc_pos.item()
        else:
            raise "Unrecognized Adversarial Learning Mode"
        loss_logs["loss_D_adv"] = self.loss_D_adv.item()
        return loss_logs


    def backward_R1(self):
        loss_dict = OrderedDict()
        if self.opt.adv_mode == "patch":
            O1 = self.sp2.detach()
            O1.requires_grad_()
            pred_real = self.discriminator(O1)
            grad_o1,  = torch.autograd.grad(
                outputs=pred_real.mean(),
                inputs=O1,
                create_graph=True,
                retain_graph=True,
                only_inputs=True
            )
            grad_penalty = (grad_o1.pow(2).sum()) / self.M2.size(0)
            # dims = list(range(1, grad_real2.ndim))
            # grad_penalty = grad_real2.sum() / self.M2.size(0)
            self.grad_penalty = grad_penalty * self.opt.lambda_R1
            loss_dict["grad_penalty"] = self.grad_penalty.item()
        elif self.opt.adv_mode == "paired":
            O1 = self.sp2.detach()
            O2 = self.sp1.detach()
            O1.requires_grad_()
            O2.requires_grad_()
            pred_real = self.discriminator.discriminate_feature(
                self.discriminator.extract_feature(O1),
                self.discriminator.extract_feature(O2)
            )
            grad_o1, grad_o2 = torch.autograd.grad(
                outputs=pred_real.mean(),
                inputs=[O1, O2],
                create_graph=True,
                retain_graph=True,
                only_inputs=True
            )
            grad_penalty = (grad_o1.pow(2).sum() + grad_o2.pow(2).sum()) / self.M2.size(0)
            # dims = list(range(1, grad_real2.ndim))
            # grad_penalty = grad_real2.sum() / self.M2.size(0)
            self.grad_penalty = grad_penalty * self.opt.lambda_R1
            loss_dict["grad_penalty"] = self.grad_penalty.item()
        else:
            raise "Unrecognized Adversarial Learning Mode"

        # pred_real = real_cls_pred.mean()

        return loss_dict

    def update(self, it):
        loss_dict = OrderedDict()

        self.zero_grad([self.opt_encoder, self.opt_decoder])
        loss_logs = self.backward_EG(it == 0)
        loss_dict.update(loss_logs)
        self.loss_EG.backward()
        # self.clip_norm([self.encoder, self.decoder])
        self.step([self.opt_encoder, self.opt_decoder])

        # if self.epoch >= self.opt.adv_start_ep:
        # self.forward(self.batch_data)
        self.zero_grad([self.discriminator])
        loss_logs = self.backward_D()
        loss_dict.update(loss_logs)
        self.loss_D_adv.backward()
        # self.clip_norm([self.discriminator])
        self.step([self.opt_discriminator])

        self.zero_grad([self.discriminator])
        loss_logs = self.backward_R1()
        loss_dict.update(loss_logs)
        self.grad_penalty.backward()
        # self.clip_norm([self.discriminator])
        self.step([self.opt_discriminator])

        return loss_dict

    def save(self, file_name, ep, total_it):
        state = {
            "encoder": self.encoder.state_dict(),
            "decoder": self.decoder.state_dict(),
            "discriminator": self.discriminator.state_dict(),

            "opt_encoder": self.opt_encoder.state_dict(),
            "opt_decoder": self.opt_decoder.state_dict(),
            "opt_discriminator": self.opt_discriminator.state_dict(),

            "ep": ep,
            "total_it": total_it,
        }

        torch.save(state, file_name)

    def resume(self, model_dir):
        # print(model_dir)
        checkpoint = torch.load(model_dir, map_location=self.opt.device)
        self.encoder.load_state_dict(checkpoint["encoder"])
        self.decoder.load_state_dict(checkpoint["decoder"])

        if self.opt.is_train:
            self.discriminator.load_state_dict(checkpoint["discriminator"])
            self.opt_discriminator.load_state_dict(checkpoint["opt_discriminator"])
            self.opt_encoder.load_state_dict(checkpoint["opt_encoder"])
            self.opt_decoder.load_state_dict(checkpoint["opt_decoder"])

        return checkpoint["ep"], checkpoint["total_it"]

    def train(self, train_dataloader, val_dataloader, plot_eval):
        """Run the training loop with per-epoch validation and checkpointing.

        The encoder, decoder and discriminator are moved to
        ``self.opt.device`` and each gets its own Adam optimizer (the
        discriminator uses four times ``self.opt.lr``). When
        ``self.opt.is_continue`` is set, training resumes from ``latest.tar``
        in ``self.opt.model_dir``.

        Args:
            train_dataloader: Sized iterable of batches accepted by
                ``self.forward``.
            val_dataloader: Sized iterable of validation batches. If it
                yields no batches, validation logging and the best-model
                check are skipped for that epoch instead of crashing.
            plot_eval: Callable invoked as ``plot_eval(data, save_dir, styles)``
                whenever the epoch counter is a multiple of
                ``self.opt.eval_every_e``.
        """
        net_list = [self.encoder, self.decoder, self.discriminator]
        self.to(net_list, self.opt.device)

        self.opt_encoder = optim.Adam(self.encoder.parameters(), lr=self.opt.lr)
        self.opt_decoder = optim.Adam(self.decoder.parameters(), lr=self.opt.lr)
        self.opt_discriminator = optim.Adam(self.discriminator.parameters(), lr=self.opt.lr*4)

        epoch = 0
        it = 0
        if self.opt.is_continue:
            model_dir = pjoin(self.opt.model_dir, "latest.tar")
            epoch, it = self.resume(model_dir)
            print("Loading model from Epoch %d" % (epoch))

        start_time = time.time()
        total_iters = self.opt.max_epoch * len(train_dataloader)
        print("Iters Per Epoch, Training: %04d, Validation: %03d" % (len(train_dataloader), len(val_dataloader)))
        min_val_loss = np.inf
        # Running sums of the training losses since the last log flush.
        logs = OrderedDict()

        while epoch < self.opt.max_epoch:
            self.epoch = epoch
            self.net_train(net_list)
            for i, batch_data in enumerate(train_dataloader):
                self.forward(batch_data)
                loss_dict = self.update(it)

                for k, v in loss_dict.items():
                    logs[k] = logs.get(k, 0) + v
                it += 1

                if it % self.opt.log_every == 0:
                    mean_loss = OrderedDict()
                    for tag, value in logs.items():
                        mean_loss[tag] = value / self.opt.log_every
                        self.logger.scalar_summary(tag, mean_loss[tag], it)
                    logs = OrderedDict()
                    print_current_loss(start_time, it, total_iters, mean_loss, epoch, i)

                if it % self.opt.save_latest == 0:
                    self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)

            self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)
            epoch += 1

            print("Validation time:")
            val_loss = None
            with torch.no_grad():
                self.net_eval(net_list)
                for batch_data in val_dataloader:
                    self.forward(batch_data)
                    loss_dict = self.backward_EG(first_iter=True)
                    if val_loss is None:
                        val_loss = loss_dict
                    else:
                        for k, v in loss_dict.items():
                            val_loss[k] += v

            if val_loss is None:
                # No validation batch was produced: there is nothing to average
                # or compare, so skip logging and the best-model check.
                print("Validation skipped: the validation dataloader produced no batches.")
            else:
                print_str = "Validation Loss:"
                for k in val_loss:
                    val_loss[k] /= len(val_dataloader)
                    print_str += ' %s: %.4f ' % (k, val_loss[k])
                self.logger.scalar_summary("val_loss", val_loss["loss_EG_vae"], epoch)
                print(print_str)

                if val_loss["loss_EG_vae"] < min_val_loss:
                    min_val_loss = val_loss["loss_EG_vae"]
                    self.save(pjoin(self.opt.model_dir, "best.tar"), epoch, it)
                    print("Best Validation Model So Far!~")

            if epoch % self.opt.eval_every_e == 0:
                # Samples 0, 2 and 4 of the cached tensors (presumably set by
                # self.forward, which is not visible here) are handed to plot_eval.
                data = torch.cat([self.M2[:6:2], self.RM2[:6:2], self.M3[:6:2], self.RM3[:6:2]],
                                 dim=0)
                styles = torch.cat([self.SID1[:6:2], self.SID1[:6:2], self.SID3[:6:2], self.SID3[:6:2]],
                                   dim=0).cpu().detach().numpy()
                data = data.permute(0, 2, 1).detach().cpu().numpy()
                save_dir = pjoin(self.opt.eval_dir, "E%04d" % (epoch))
                os.makedirs(save_dir, exist_ok=True)
                plot_eval(data, save_dir, styles)


class MotionAEKLTrainer(BaseTrainer):
    """Trainer for an encoder/decoder pair with a latent regulariser.

    The encoder is called on the input with its last four channels removed
    and must return ``(z, mu, logvar)``; the decoder maps ``z`` back to a
    reconstruction. The training loss is a smooth-L1 reconstruction term plus
    either a KL term (``opt.use_vae``) or sparsity and smoothness penalties
    on ``z``.
    """

    def __init__(self, opt, encoder, decoder):
        """Store options and networks; build the logger and criteria when training."""
        self.opt = opt
        self.encoder = encoder
        self.decoder = decoder

        if self.opt.is_train:
            self.logger = Logger(self.opt.log_dir)
            self.mse_criterion = torch.nn.MSELoss()
            self.l1_criterion = torch.nn.SmoothL1Loss()

    @staticmethod
    def reparametrize(mu, logvar):
        """Sample ``mu + eps * exp(logvar / 2)`` with ``eps`` drawn from N(0, 1)."""
        s_var = logvar.mul(0.5).exp_()
        eps = s_var.data.new(s_var.size()).normal_()
        return eps.mul(s_var).add_(mu)

    @staticmethod
    def ones_like(tensor, val=1.):
        """Return a float32 tensor shaped like ``tensor``, filled with ``val``, on its device."""
        return torch.full(tensor.size(), val, dtype=torch.float32, device=tensor.device)

    @staticmethod
    def zeros_like(tensor, val=0.):
        """Return a float32 tensor shaped like ``tensor``, filled with ``val``, on its device."""
        return torch.full(tensor.size(), val, dtype=torch.float32, device=tensor.device)

    @staticmethod
    def kl_criterion(mu1, logvar1, mu2, logvar2):
        """Mean element-wise KL( N(mu1, sigma1^2) || N(mu2, sigma2^2) ).

        Per element:
        log(sigma2 / sigma1) + (sigma1^2 + (mu1 - mu2)^2) / (2 * sigma2^2) - 1/2
        """
        sigma1 = logvar1.mul(0.5).exp()
        sigma2 = logvar2.mul(0.5).exp()
        kld = torch.log(sigma2 / sigma1) + (torch.exp(logvar1) + (mu1 - mu2) ** 2) / (
                2 * torch.exp(logvar2)) - 1 / 2
        return kld.sum() / np.prod(mu1.shape)

    @staticmethod
    def kl_criterion_unit(mu, logvar):
        """Mean element-wise KL( N(mu, sigma^2) || N(0, 1) ).

        Per element: (sigma^2 + mu^2 - log(sigma^2) - 1) / 2
        """
        kld = ((torch.exp(logvar) + mu ** 2) - logvar - 1) / 2
        return kld.sum() / np.prod(mu.shape)

    def forward(self, batch_data):
        """Encode and decode one batch, caching the tensors needed by ``backward``.

        ``batch_data`` must unpack into five items; only the first is used.
        Its last two axes are swapped before use, and the encoder sees it with
        the last four entries along dim 1 removed.
        """
        M, _, _, _, _ = batch_data

        M = M.permute(0, 2, 1).to(self.opt.device).float().detach()
        z, mu, logvar = self.encoder(M[:, :-4])
        RM = self.decoder(z)
        self.M, self.RM = M, RM
        self.mu, self.logvar = mu, logvar
        self.z = z

    def backward(self):
        """Compute the losses for the cached batch (no gradient step is taken).

        Sets ``self.loss`` (the total) together with the individual terms and
        returns an ``OrderedDict`` of their Python-float values.
        """
        # (input, target) order; smooth L1 is symmetric so the value is unchanged.
        self.loss_rec = self.l1_criterion(self.RM, self.M)
        loss_logs = OrderedDict({})
        loss_logs["loss_rec"] = self.loss_rec.item()

        # Build the total out-of-place: an in-place ``+=`` on an alias of
        # ``self.loss_rec`` would overwrite the stored reconstruction loss.
        if self.opt.use_vae:
            self.loss_kld = self.kl_criterion_unit(self.mu, self.logvar)
            loss_logs["loss_kld"] = self.loss_kld.item()
            self.loss = self.loss_rec + self.loss_kld * self.opt.lambda_kld
        else:
            self.loss_sparsity = torch.mean(torch.abs(self.z))
            # Penalise differences between neighbouring entries along z's last axis.
            self.loss_smooth = self.l1_criterion(self.z[..., 1:], self.z[..., :-1])
            loss_logs["loss_sparsity"] = self.loss_sparsity.item()
            loss_logs["loss_smooth"] = self.loss_smooth.item()
            self.loss = self.loss_rec + (
                self.loss_smooth * self.opt.lambda_sms + self.loss_sparsity * self.opt.lambda_spa)

        loss_logs["loss"] = self.loss.item()
        return loss_logs

    def update(self):
        """Run one optimisation step on the cached batch and return the loss logs."""
        self.zero_grad([self.opt_encoder, self.opt_decoder])
        loss_logs = self.backward()
        self.loss.backward()
        self.clip_norm([self.encoder, self.decoder])
        self.step([self.opt_encoder, self.opt_decoder])
        return loss_logs

    def save(self, file_name, ep, total_it):
        """Write networks, optimizers, epoch and iteration counters to ``file_name``."""
        state = {
            "encoder": self.encoder.state_dict(),
            "decoder": self.decoder.state_dict(),

            "opt_encoder": self.opt_encoder.state_dict(),
            "opt_decoder": self.opt_decoder.state_dict(),

            "ep": ep,
            "total_it": total_it,
        }

        torch.save(state, file_name)

    def resume(self, model_dir):
        """Load a checkpoint and return its ``(ep, total_it)`` counters.

        Optimizer states are only restored when ``self.opt.is_train`` is set.
        """
        checkpoint = torch.load(model_dir, map_location=self.opt.device)
        self.encoder.load_state_dict(checkpoint["encoder"])
        self.decoder.load_state_dict(checkpoint["decoder"])

        if self.opt.is_train:
            self.opt_encoder.load_state_dict(checkpoint["opt_encoder"])
            self.opt_decoder.load_state_dict(checkpoint["opt_decoder"])

        return checkpoint["ep"], checkpoint["total_it"]

    def train(self, train_dataloader, val_dataloader, plot_eval):
        """Run the training loop with per-epoch validation and checkpointing.

        Args:
            train_dataloader: Sized iterable of batches accepted by ``forward``.
            val_dataloader: Sized iterable of validation batches. If it yields
                no batches, validation logging and the best-model check are
                skipped for that epoch instead of crashing.
            plot_eval: Callable invoked as ``plot_eval(data, save_dir)``
                whenever the epoch counter is a multiple of
                ``self.opt.eval_every_e``.
        """
        net_list = [self.encoder, self.decoder]
        self.to(net_list, self.opt.device)

        self.opt_encoder = optim.Adam(self.encoder.parameters(), lr=self.opt.lr)
        self.opt_decoder = optim.Adam(self.decoder.parameters(), lr=self.opt.lr)

        epoch = 0
        it = 0
        if self.opt.is_continue:
            model_dir = pjoin(self.opt.model_dir, "latest.tar")
            epoch, it = self.resume(model_dir)
            print("Loading model from Epoch %d" % (epoch))

        start_time = time.time()
        total_iters = self.opt.max_epoch * len(train_dataloader)
        print("Iters Per Epoch, Training: %04d, Validation: %03d" % (len(train_dataloader), len(val_dataloader)))
        min_val_loss = np.inf
        # Running sums of the training losses since the last log flush.
        logs = OrderedDict()

        while epoch < self.opt.max_epoch:
            self.net_train(net_list)
            for i, batch_data in enumerate(train_dataloader):
                self.forward(batch_data)
                loss_dict = self.update()

                for k, v in loss_dict.items():
                    logs[k] = logs.get(k, 0) + v
                it += 1

                if it % self.opt.log_every == 0:
                    mean_loss = OrderedDict()
                    for tag, value in logs.items():
                        mean_loss[tag] = value / self.opt.log_every
                        self.logger.scalar_summary(tag, mean_loss[tag], it)
                    logs = OrderedDict()
                    print_current_loss(start_time, it, total_iters, mean_loss, epoch, i)

                if it % self.opt.save_latest == 0:
                    self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)

            self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)
            epoch += 1

            print("Validation time:")
            val_loss = None
            with torch.no_grad():
                self.net_eval(net_list)
                for batch_data in val_dataloader:
                    self.forward(batch_data)
                    loss_dict = self.backward()
                    if val_loss is None:
                        val_loss = loss_dict
                    else:
                        for k, v in loss_dict.items():
                            val_loss[k] += v

            if val_loss is None:
                # Nothing to average or compare without validation batches.
                print("Validation skipped: the validation dataloader produced no batches.")
            else:
                print_str = "Validation Loss:"
                for k in val_loss:
                    val_loss[k] /= len(val_dataloader)
                    print_str += ' %s: %.4f ' % (k, val_loss[k])
                self.logger.scalar_summary("val_loss", val_loss["loss"], epoch)
                print(print_str)

                if val_loss["loss"] < min_val_loss:
                    min_val_loss = val_loss["loss"]
                    self.save(pjoin(self.opt.model_dir, "best.tar"), epoch, it)
                    print("Best Validation Model So Far!~")

            if epoch % self.opt.eval_every_e == 0:
                # Samples 0, 2 and 4 of the last processed batch: inputs first,
                # then their reconstructions.
                data = torch.cat([self.M[:6:2], self.RM[:6:2]], dim=0)
                data = data.permute(0, 2, 1).detach().cpu().numpy()
                save_dir = pjoin(self.opt.eval_dir, "E%04d" % (epoch))
                os.makedirs(save_dir, exist_ok=True)
                plot_eval(data, save_dir)


class GMRTrainer(BaseTrainer):
    """Trainer for a regressor that predicts features 1:3 from the remaining ones.

    The network input is feature 0 concatenated with features 3 onward
    (along the last axis of the batch); the target is features 1 and 2.
    Training minimises a smooth-L1 loss between prediction and target.
    """

    def __init__(self, opt, regressor):
        """Store options and network; build the logger and criterion when training."""
        self.opt = opt
        self.regressor = regressor

        if self.opt.is_train:
            self.logger = Logger(self.opt.log_dir)
            self.l1_criterion = torch.nn.SmoothL1Loss()

    def save(self, file_name, ep, total_it):
        """Write regressor, optimizer, epoch and iteration counters to ``file_name``."""
        state = {
            "regressor": self.regressor.state_dict(),
            "opt_regressor": self.opt_regressor.state_dict(),

            "ep": ep,
            "total_it": total_it,
        }

        torch.save(state, file_name)

    def resume(self, model_dir):
        """Load a checkpoint and return its ``(ep, total_it)`` counters.

        The optimizer state is only restored when ``self.opt.is_train`` is set.
        """
        checkpoint = torch.load(model_dir, map_location=self.opt.device)
        self.regressor.load_state_dict(checkpoint["regressor"])

        if self.opt.is_train:
            self.opt_regressor.load_state_dict(checkpoint["opt_regressor"])
        return checkpoint["ep"], checkpoint["total_it"]

    def forward(self, batch_data):
        """Split a batch into input/target, optionally add noise, and run the regressor.

        For ``opt.dataset_name == "cmu"`` the batch must unpack into five
        items; the first is used with its last four features removed.
        Otherwise the batch itself is the feature tensor.

        With probability ``opt.noise_prob`` Gaussian noise scaled by
        ``opt.noise_scale`` times a uniform random factor is added to the
        input.
        NOTE(review): the noise is also applied when this is called from the
        validation loop - confirm that is intended.
        """
        if self.opt.dataset_name == "cmu":
            M, _, _, _, _ = batch_data
            M = M[..., :-4]
        else:
            M = batch_data
        # Local named ``inputs`` so the ``input`` builtin is not shadowed.
        inputs = torch.cat([M[..., 0:1], M[..., 3:]], dim=-1)
        target = M[..., 1:3]

        if random.random() < self.opt.noise_prob:
            inputs = inputs + torch.zeros_like(inputs).normal_() * self.opt.noise_scale * random.random()

        inputs = inputs.permute(0, 2, 1).float().to(self.opt.device)
        target = target.permute(0, 2, 1).float().to(self.opt.device)
        pred = self.regressor(inputs)
        self.pred = pred
        self.target = target
        self.input = inputs

    def update(self):
        """Run one optimisation step on the cached batch and return the loss logs."""
        self.zero_grad([self.opt_regressor])
        loss_logs = self.backward()
        self.loss.backward()
        self.clip_norm([self.regressor])
        self.step([self.opt_regressor])
        return loss_logs

    def backward(self):
        """Compute ``self.loss`` for the cached batch and return it in an ``OrderedDict``."""
        # (input, target) order; smooth L1 is symmetric so the value is unchanged.
        self.loss = self.l1_criterion(self.pred, self.target)

        loss_dict = OrderedDict({})
        loss_dict["loss"] = self.loss.item()
        return loss_dict

    def train(self, train_dataloader, val_dataloader, plot_eval):
        """Run the training loop with per-epoch validation and checkpointing.

        Args:
            train_dataloader: Sized iterable of batches accepted by ``forward``.
            val_dataloader: Sized iterable of validation batches. If it yields
                no batches, validation logging and the best-model check are
                skipped for that epoch instead of crashing.
            plot_eval: Callable invoked as ``plot_eval(data, save_dir)``
                whenever the epoch counter is a multiple of
                ``self.opt.eval_every_e``.
        """
        net_list = [self.regressor]
        self.to(net_list, self.opt.device)

        self.opt_regressor = optim.Adam(self.regressor.parameters(), lr=self.opt.lr)

        epoch = 0
        it = 0
        if self.opt.is_continue:
            model_dir = pjoin(self.opt.model_dir, "latest.tar")
            epoch, it = self.resume(model_dir)
            print("Loading model from Epoch %d" % (epoch))

        start_time = time.time()
        total_iters = self.opt.max_epoch * len(train_dataloader)
        print("Iters Per Epoch, Training: %04d, Validation: %03d" % (len(train_dataloader), len(val_dataloader)))
        min_val_loss = np.inf
        # Running sums of the training losses since the last log flush.
        logs = OrderedDict()

        while epoch < self.opt.max_epoch:
            self.net_train(net_list)
            for i, batch_data in enumerate(train_dataloader):
                self.forward(batch_data)
                loss_dict = self.update()

                for k, v in loss_dict.items():
                    logs[k] = logs.get(k, 0) + v
                it += 1

                if it % self.opt.log_every == 0:
                    mean_loss = OrderedDict()
                    for tag, value in logs.items():
                        mean_loss[tag] = value / self.opt.log_every
                        self.logger.scalar_summary(tag, mean_loss[tag], it)
                    logs = OrderedDict()
                    print_current_loss(start_time, it, total_iters, mean_loss, epoch, i)

                if it % self.opt.save_latest == 0:
                    self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)

            self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)
            epoch += 1

            print("Validation time:")
            val_loss = None
            with torch.no_grad():
                self.net_eval(net_list)
                for batch_data in val_dataloader:
                    self.forward(batch_data)
                    loss_dict = self.backward()
                    if val_loss is None:
                        val_loss = loss_dict
                    else:
                        for k, v in loss_dict.items():
                            val_loss[k] += v

            if val_loss is None:
                # Nothing to average or compare without validation batches.
                print("Validation skipped: the validation dataloader produced no batches.")
            else:
                print_str = "Validation Loss:"
                for k in val_loss:
                    val_loss[k] /= len(val_dataloader)
                    print_str += ' %s: %.4f ' % (k, val_loss[k])
                self.logger.scalar_summary("val_loss", val_loss["loss"], epoch)
                print(print_str)

                if val_loss["loss"] < min_val_loss:
                    min_val_loss = val_loss["loss"]
                    self.save(pjoin(self.opt.model_dir, "best.tar"), epoch, it)
                    print("Best Validation Model So Far!~")

            if epoch % self.opt.eval_every_e == 0:
                # Rebuild full feature rows for the first two cached samples,
                # once with the target and once with the prediction in slots 1:3.
                local = self.input[:2].permute(0, 2, 1)
                target = self.target[:2].permute(0, 2, 1)
                pred = self.pred[:2].permute(0, 2, 1)
                # The last four input features are appended a second time;
                # presumably padding for features removed in forward - TODO confirm.
                real = torch.cat([local[..., 0:1], target, local[..., 1:], local[..., -4:]], dim=-1)
                fake = torch.cat([local[..., 0:1], pred, local[..., 1:], local[..., -4:]], dim=-1)
                data = torch.cat([real, fake], dim=0).detach().cpu().numpy()
                save_dir = pjoin(self.opt.eval_dir, "E%04d" % (epoch))
                os.makedirs(save_dir, exist_ok=True)
                plot_eval(data, save_dir)


class ClassifierTrainer(BaseTrainer):
    """Trainer for a classifier evaluated on two inputs that share one label tensor.

    The classifier must return a pair whose second element holds the class
    scores. Only the cross-entropy on the first input is optimised; the
    second input contributes to the reported accuracy.
    """

    def __init__(self, opt, classifier):
        """Store options and network; build the logger and criteria when training."""
        self.opt = opt
        self.classifier = classifier

        if self.opt.is_train:
            self.logger = Logger(self.opt.log_dir)
            self.l1_criterion = torch.nn.SmoothL1Loss()
            self.cls_criterion = torch.nn.CrossEntropyLoss()

    def save(self, file_name, ep, total_it):
        """Write classifier, optimizer, epoch and iteration counters to ``file_name``."""
        state = {
            "classifier": self.classifier.state_dict(),
            "opt_classifier": self.opt_classifier.state_dict(),

            "ep": ep,
            "total_it": total_it,
        }

        torch.save(state, file_name)

    def resume(self, model_dir):
        """Load a checkpoint and return its ``(ep, total_it)`` counters.

        The optimizer state is only restored when ``self.opt.is_train`` is set.
        """
        checkpoint = torch.load(model_dir, map_location=self.opt.device)
        self.classifier.load_state_dict(checkpoint["classifier"])

        if self.opt.is_train:
            self.opt_classifier.load_state_dict(checkpoint["opt_classifier"])
        return checkpoint["ep"], checkpoint["total_it"]

    def forward(self, batch_data):
        """Classify the first two tensors of the batch and cache predictions and labels.

        ``batch_data`` must unpack into nine items; items 0 and 1 are the two
        inputs and item 6 holds the integer class labels. Both inputs have
        their last two axes swapped and their last four entries along dim 1
        removed before they reach the classifier.
        """
        M1, M2, _, _, _, _, SID1, _, _ = batch_data

        M1 = M1.permute(0, 2, 1).float().to(self.opt.device)
        M2 = M2.permute(0, 2, 1).float().to(self.opt.device)
        SID1 = SID1.long().to(self.opt.device)

        _, pred1 = self.classifier(M1[:, :-4])
        _, pred2 = self.classifier(M2[:, :-4])

        self.pred1 = pred1
        self.pred2 = pred2
        self.target = SID1
        # Previously this stored the ``input`` builtin function, because no
        # local of that name existed; keep the first input tensor instead.
        self.input = M1

    def update(self):
        """Run one optimisation step on the cached batch and return the loss logs."""
        self.zero_grad([self.opt_classifier])
        loss_logs = self.backward()
        self.loss.backward()
        self.clip_norm([self.classifier])
        self.step([self.opt_classifier])
        return loss_logs

    def backward(self):
        """Compute the loss and accuracy for the cached batch.

        ``self.loss`` is the cross-entropy of the first prediction only;
        ``self.loss2`` is computed and stored but not added to it. The
        accuracy is averaged over both predictions.
        """
        self.loss1 = self.cls_criterion(self.pred1, self.target)
        self.loss2 = self.cls_criterion(self.pred2, self.target)
        self.loss = self.loss1

        pred_id1 = self.pred1.argmax(dim=-1)
        pred_id2 = self.pred2.argmax(dim=-1)
        correct1 = (pred_id1 == self.target).sum()
        correct2 = (pred_id2 == self.target).sum()
        accuracy = (correct2 + correct1) / len(self.target) / 2

        loss_dict = OrderedDict({})
        loss_dict["loss"] = self.loss.item()
        loss_dict["accuracy"] = accuracy.item()
        return loss_dict

    def train(self, train_dataloader, val_dataloader):
        """Run the training loop with per-epoch validation and checkpointing.

        Args:
            train_dataloader: Sized iterable of batches accepted by ``forward``.
            val_dataloader: Sized iterable of validation batches. If it yields
                no batches, validation logging and the best-model check are
                skipped for that epoch instead of crashing.
        """
        net_list = [self.classifier]
        self.to(net_list, self.opt.device)

        self.opt_classifier = optim.Adam(self.classifier.parameters(), lr=self.opt.lr)

        epoch = 0
        it = 0
        if self.opt.is_continue:
            model_dir = pjoin(self.opt.model_dir, "latest.tar")
            epoch, it = self.resume(model_dir)
            print("Loading model from Epoch %d" % (epoch))

        start_time = time.time()
        total_iters = self.opt.max_epoch * len(train_dataloader)
        print("Iters Per Epoch, Training: %04d, Validation: %03d" % (len(train_dataloader), len(val_dataloader)))
        min_val_loss = np.inf
        # Running sums of the training metrics since the last log flush.
        logs = OrderedDict()

        while epoch < self.opt.max_epoch:
            self.net_train(net_list)
            for i, batch_data in enumerate(train_dataloader):
                self.forward(batch_data)
                loss_dict = self.update()

                for k, v in loss_dict.items():
                    logs[k] = logs.get(k, 0) + v
                it += 1

                if it % self.opt.log_every == 0:
                    mean_loss = OrderedDict()
                    for tag, value in logs.items():
                        mean_loss[tag] = value / self.opt.log_every
                        self.logger.scalar_summary(tag, mean_loss[tag], it)
                    logs = OrderedDict()
                    print_current_loss(start_time, it, total_iters, mean_loss, epoch, i)

                if it % self.opt.save_latest == 0:
                    self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)

            self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)
            epoch += 1

            print("Validation time:")
            val_loss = None
            with torch.no_grad():
                self.net_eval(net_list)
                for batch_data in val_dataloader:
                    self.forward(batch_data)
                    loss_dict = self.backward()
                    if val_loss is None:
                        val_loss = loss_dict
                    else:
                        for k, v in loss_dict.items():
                            val_loss[k] += v

            if val_loss is None:
                # Nothing to average or compare without validation batches.
                print("Validation skipped: the validation dataloader produced no batches.")
            else:
                print_str = "Validation Loss:"
                for k in val_loss:
                    val_loss[k] /= len(val_dataloader)
                    print_str += ' %s: %.4f ' % (k, val_loss[k])
                self.logger.scalar_summary("val_loss", val_loss["loss"], epoch)
                print(print_str)

                if val_loss["loss"] < min_val_loss:
                    min_val_loss = val_loss["loss"]
                    self.save(pjoin(self.opt.model_dir, "best.tar"), epoch, it)
                    print("Best Validation Model So Far!~")


class ActionClassifierTrainer(ClassifierTrainer):
    """Single-input variant of ``ClassifierTrainer``.

    Batches are ``(M, A, AID)`` triples: ``M`` is fed to the classifier with
    its last four features removed and ``AID`` supplies the integer labels.
    """

    def __init__(self, opt, classifier):
        """Delegate all setup to ``ClassifierTrainer``."""
        super().__init__(opt, classifier)

    def forward(self, batch_data):
        """Classify one batch and cache the scores and labels on ``self``."""
        motion, _, action_ids = batch_data
        device = self.opt.device

        motion = motion.float().to(device)
        labels = action_ids.long().to(device)

        # The classifier returns a pair; only its second element is kept.
        _, scores = self.classifier(motion[..., :-4])

        self.pred, self.target = scores, labels

    def backward(self):
        """Compute cross-entropy loss and accuracy for the cached batch.

        Stores the loss tensor in ``self.loss`` and returns both metrics as
        Python floats in an ``OrderedDict``.
        """
        scores, labels = self.pred, self.target
        self.loss = self.cls_criterion(scores, labels)

        hits = (scores.argmax(dim=-1) == labels).sum()
        accuracy = hits / len(labels)

        return OrderedDict([
            ("loss", self.loss.item()),
            ("accuracy", accuracy.item()),
        ])

class ContrastiveTrainer(BaseTrainer):
    """Trainer that fits an embedding model with a pairwise contrastive loss.

    Each batch supplies an anchor, a positive and a negative sample; which
    batch entries play those roles depends on ``opt.is_style`` /
    ``opt.is_content``.
    """

    def __init__(self, opt, model):
        """Store options and network; build the logger and loss when training."""
        self.opt = opt
        self.model = model

        if self.opt.is_train:
            self.logger = Logger(self.opt.log_dir)
            self.contrastive_loss = ContrastiveLoss(self.opt.negative_margin)

    def save(self, file_name, ep, total_it):
        """Write model, optimizer, epoch and iteration counters to ``file_name``."""
        state = {
            "model": self.model.state_dict(),
            "opt_model": self.opt_model.state_dict(),

            "ep": ep,
            "total_it": total_it,
        }

        torch.save(state, file_name)

    def resume(self, model_dir):
        """Load a checkpoint and return its ``(ep, total_it)`` counters.

        The optimizer state is only restored when ``self.opt.is_train`` is set.
        """
        checkpoint = torch.load(model_dir, map_location=self.opt.device)
        self.model.load_state_dict(checkpoint["model"])

        if self.opt.is_train:
            self.opt_model.load_state_dict(checkpoint["opt_model"])
        return checkpoint["ep"], checkpoint["total_it"]

    def forward(self, batch_data):
        """Embed the anchor, positive and negative samples of one batch.

        Raises:
            ValueError: If neither ``opt.is_style`` nor ``opt.is_content`` is set.
        """
        # RM1: random motion
        # RM2: motion close to RM1
        # RM3: motion from the same sequence as RM1
        # RM4: motion from a different style than RM1
        RM1, RM2, RM3, RM4, _, _, _, _, _ = batch_data
        M = RM1
        if self.opt.is_style:
            PM = RM3
            NM = RM4
        elif self.opt.is_content:
            PM = RM2
            # The negative is drawn from RM3 or RM4 with equal probability.
            NM = RM3 if random.random() > 0.5 else RM4
        else:
            # Raising a bare string is itself a TypeError in Python 3.
            raise ValueError("Must learn in one mode")

        M = M.permute(0, 2, 1).float().to(self.opt.device)
        PM = PM.permute(0, 2, 1).float().to(self.opt.device)
        NM = NM.permute(0, 2, 1).float().to(self.opt.device)

        # The last four entries along dim 1 are removed before embedding.
        featM = self.model(M[:, :-4])
        featNM = self.model(NM[:, :-4])
        featPM = self.model(PM[:, :-4])

        self.featM = featM
        self.featNM = featNM
        self.featPM = featPM

    def update(self):
        """Run one optimisation step on the cached batch and return the loss logs."""
        self.zero_grad([self.opt_model])
        loss_logs = self.backward()
        self.loss.backward()
        self.clip_norm([self.model])
        self.step([self.opt_model])
        return loss_logs

    def backward(self):
        """Compute positive-pair and negative-pair losses for the cached batch.

        Positive pairs are passed to ``self.contrastive_loss`` with label 0
        and negative pairs with label 1; ``self.loss`` is their sum.
        """
        batch_size = self.featM.shape[0]
        device = self.featM.device

        pos_labels = torch.zeros(batch_size).to(device)
        self.loss_pos = self.contrastive_loss(self.featM, self.featPM, pos_labels)

        neg_labels = torch.ones(batch_size).to(device)
        self.loss_neg = self.contrastive_loss(self.featM, self.featNM, neg_labels)
        self.loss = self.loss_pos + self.loss_neg

        loss_logs = OrderedDict({})
        loss_logs['loss'] = self.loss.item()
        loss_logs['loss_pos'] = self.loss_pos.item()
        loss_logs['loss_neg'] = self.loss_neg.item()
        return loss_logs

    def train(self, train_dataloader, val_dataloader):
        """Run the training loop with per-epoch validation and checkpointing.

        Args:
            train_dataloader: Sized iterable of batches accepted by ``forward``.
            val_dataloader: Sized iterable of validation batches. If it yields
                no batches, validation logging and the best-model check are
                skipped for that epoch instead of crashing.
        """
        net_list = [self.model]
        self.to(net_list, self.opt.device)

        self.opt_model = optim.Adam(self.model.parameters(), lr=self.opt.lr)

        epoch = 0
        it = 0
        if self.opt.is_continue:
            model_dir = pjoin(self.opt.model_dir, "latest.tar")
            epoch, it = self.resume(model_dir)
            print("Loading model from Epoch %d" % (epoch))

        start_time = time.time()
        total_iters = self.opt.max_epoch * len(train_dataloader)
        print("Iters Per Epoch, Training: %04d, Validation: %03d" % (len(train_dataloader), len(val_dataloader)))
        min_val_loss = np.inf
        # Running sums of the training losses since the last log flush.
        logs = OrderedDict()

        while epoch < self.opt.max_epoch:
            self.net_train(net_list)
            for i, batch_data in enumerate(train_dataloader):
                self.forward(batch_data)
                loss_dict = self.update()

                for k, v in loss_dict.items():
                    logs[k] = logs.get(k, 0) + v
                it += 1

                if it % self.opt.log_every == 0:
                    mean_loss = OrderedDict()
                    for tag, value in logs.items():
                        mean_loss[tag] = value / self.opt.log_every
                        self.logger.scalar_summary(tag, mean_loss[tag], it)
                    logs = OrderedDict()
                    print_current_loss(start_time, it, total_iters, mean_loss, epoch, i)

                if it % self.opt.save_latest == 0:
                    self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)

            self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)
            epoch += 1

            print("Validation time:")
            val_loss = None
            with torch.no_grad():
                self.net_eval(net_list)
                for batch_data in val_dataloader:
                    self.forward(batch_data)
                    loss_dict = self.backward()
                    if val_loss is None:
                        val_loss = loss_dict
                    else:
                        for k, v in loss_dict.items():
                            val_loss[k] += v

            if val_loss is None:
                # Nothing to average or compare without validation batches.
                print("Validation skipped: the validation dataloader produced no batches.")
            else:
                print_str = "Validation Loss:"
                for k in val_loss:
                    val_loss[k] /= len(val_dataloader)
                    print_str += ' %s: %.4f ' % (k, val_loss[k])
                self.logger.scalar_summary("val_loss", val_loss["loss"], epoch)
                print(print_str)

                if val_loss["loss"] < min_val_loss:
                    min_val_loss = val_loss["loss"]
                    self.save(pjoin(self.opt.model_dir, "best.tar"), epoch, it)
                    print("Best Validation Model So Far!~")


class LatentVAETrainer(BaseTrainer):
    def __init__(self, opt, encoder, decoder, ae_encoder, ae_decoder):
        self.opt = opt
        self.encoder = encoder
        self.decoder = decoder
        self.ae_encoder = ae_encoder
        self.ae_decoder = ae_decoder
        # self.discriminator = discriminator

        if self.opt.is_train:
            self.logger = Logger(self.opt.log_dir)
            self.mse_criterion = torch.nn.MSELoss()
            self.l1_criterion = torch.nn.SmoothL1Loss()
            # self.cross_entropy_criterion = torch.nn.CrossEntropyLoss()

    @staticmethod
    def reparametrize(mu, logvar):
        s_var = logvar.mul(0.5).exp_()
        eps = s_var.data.new(s_var.size()).normal_()
        return eps.mul(s_var).add_(mu)

    @staticmethod
    def ones_like(tensor, val=1.):
        return torch.FloatTensor(tensor.size()).fill_(val).to(tensor.device).requires_grad_(False)

    @staticmethod
    def zeros_like(tensor, val=0.):
        return torch.FloatTensor(tensor.size()).fill_(val).to(tensor.device).requires_grad_(False)

    @staticmethod
    def kl_criterion(mu1, logvar1, mu2, logvar2):
        # KL( N(mu1, sigma2_1) || N(mu_2, sigma2_2))
        # loss = log(sigma2/sigma1) + (sigma1^2 + (mu1 - mu2)^2)/(2*sigma2^2) - 1/2
        sigma1 = logvar1.mul(0.5).exp()
        sigma2 = logvar2.mul(0.5).exp()
        kld = torch.log(sigma2 / sigma1) + (torch.exp(logvar1) + (mu1 - mu2) ** 2) / (
                2 * torch.exp(logvar2)) - 1 / 2
        return kld.sum() / np.prod(mu1.shape)

    @staticmethod
    def kl_criterion_unit(mu, logvar):
        # KL( N(mu1, sigma2_1) || N(mu_2, sigma2_2))
        # loss = log(sigma2/sigma1) + (sigma1^2 + (mu1 - mu2)^2)/(2*sigma2^2) - 1/2
        kld = ((torch.exp(logvar) + mu ** 2) - logvar - 1) / 2
        return kld.sum() / np.prod(mu.shape)

    def forward(self, batch_data):
        """Run one batch through the autoencoder and the latent VAE and cache the results.

        Nothing is returned; every tensor that ``backward`` (and the plotting
        code in ``train``) reads is stored on ``self`` at the end of the method.

        Suffix convention used below: ``1`` and ``2`` are the two motions of the
        batch, ``3`` is ``2`` with the members of every consecutive pair swapped
        (see ``swap``), ``4`` is what the encoder extracts from the cross
        reconstruction ``RLM3``.  ``sp*`` is the first encoder output (used as-is
        as a latent), ``gl_mu*``/``gl_logvar*`` are sampled with ``reparametrize``.

        Args:
            batch_data: A 5-tuple ``(M1, M2, A1, S1, SID1)`` when
                ``opt.dataset_name == "cmu"``; otherwise a 9-tuple of which
                positions 0, 1, 4, 5 and 6 are used as ``M1, M2, A1, S1, SID1``.
                The batch size must be even because ``swap`` exchanges the two
                members of each consecutive pair.
        """
        if self.opt.dataset_name == "cmu":
            M1, M2, A1, S1, SID1 = batch_data
        else:
            # MS is unpacked but only referenced by the disabled code below.
            M1, M2, MS, _, A1, S1, SID1, _, _ = batch_data
        # Both motions share the same action / style inputs.
        A2, S2 = A1, S1

        # if self.opt.use_style:
        #     M2 = MS.clone()
        # (B, seq_len, pose_dim) -> (B, pose_dim, seq_len)
        # M1[..., 1:3] *= 0
        # M2[..., 1:3] *= 0

        M1 = M1.permute(0, 2, 1).to(self.opt.device).float().detach()
        M2 = M2.permute(0, 2, 1).to(self.opt.device).float().detach()
        # Pair-swapped copies: sample 2k trades places with sample 2k+1.
        M3 = self.swap(M2)
        SID3 = self.swap(SID1)

        # Conditioning inputs are moved to the device only when enabled;
        # otherwise the networks receive None.
        if self.opt.use_action:
            A1 = A1.to(self.opt.device).float().detach()
            A2 = A2.to(self.opt.device).float().detach()
            A3 = self.swap(A2)
        else:
            A1, A2, A3 = None, None, None

        if self.opt.use_style:
            S1 = S1.to(self.opt.device).float().detach()
            S2 = S2.to(self.opt.device).float().detach()
            S3 = self.swap(S2)
        else:
            S1, S2, S3 = None, None, None

        # The autoencoder sees everything except the last 4 entries along dim 1
        # (NOTE(review): what those 4 channels hold is not visible here - confirm).
        LM1, _, _ = self.ae_encoder(M1[:, :-4])
        LM2, _, _ = self.ae_encoder(M2[:, :-4])
        LM3 = self.swap(LM2)

        # Detach so that no gradient reaches ae_encoder through the latents.
        LM1, LM2, LM3 = LM1.detach(), LM2.detach(), LM3.detach()

        sp1, gl_mu1, gl_logvar1 = self.encoder(LM1, A1, S1)
        sp2, gl_mu2, gl_logvar2 = self.encoder(LM2, A2, S2)
        # The "3" codes are not re-encoded; they are the pair-swapped "2" codes.
        sp3, gl_mu3, gl_logvar3 = self.swap(sp2), self.swap(gl_mu2), self.swap(gl_logvar2)


        # z_sp1 = self.reparametrize(sp_mu1, sp_logvar1)
        z_sp1 = sp1
        z_gl1 = self.reparametrize(gl_mu1, gl_logvar1)

        # z_sp2 = self.reparametrize(sp_mu2, sp_logvar2)
        z_sp2 = sp2
        z_gl2 = self.reparametrize(gl_mu2, gl_logvar2)

        # Cross pass: un-swapped sp2 (detached) combined with a sample from the
        # pair-swapped gl statistics.
        # z_sp3 = self.reparametrize(sp_mu2.detach(), sp_logvar2.detach())
        z_sp3 = z_sp2.detach()
        z_gl3 = self.reparametrize(gl_mu3, gl_logvar3)


        # RLM1 / RLM2 are self-reconstructions; RLM3 is the cross reconstruction
        # (decoded with A2 and the swapped S3).
        RLM1 = self.decoder(z_sp1, z_gl1, A1, S1)
        RLM2 = self.decoder(z_sp2, z_gl2, A2, S2)
        RLM3 = self.decoder(z_sp3, z_gl3, A2, S3)

        # Re-encode the cross reconstruction; backward() compares sp4 with sp2
        # (loss_rec_lat) and gl_*4 with gl_*3 (loss_kld_gl_m34).
        sp4, gl_mu4, gl_logvar4 = self.encoder(RLM3, A2, S3)


        z_sp4 = sp4
        # Sampled from the detached gl statistics of motion 2.
        z_gl4 = self.reparametrize(gl_mu2.detach(), gl_logvar2.detach())

        # z_sp5 = self.reparametrize(sp_mu3.detach(), sp_logvar3.detach())
        z_sp5 = sp3.detach()
        z_gl5 = self.reparametrize(gl_mu4, gl_logvar4)


        # Second-round decodes; backward() compares RRLM2 with LM2 and RRLM3
        # with LM3.
        RRLM2 = self.decoder(z_sp4, z_gl4, A2, S2)
        RRLM3 = self.decoder(z_sp5, z_gl5, A3, S3)

        # Map every reconstructed latent back through ae_decoder.
        RM1 = self.ae_decoder(RLM1)
        RM2 = self.ae_decoder(RLM2)
        RM3 = self.ae_decoder(RLM3)
        RRM2 = self.ae_decoder(RRLM2)
        RRM3 = self.ae_decoder(RRLM3)

        # Cache everything read later.  RLM3 itself is not stored, only RM3.
        self.M1, self.M2, self.M3 = M1, M2, M3
        self.LM1, self.LM2, self.LM3 = LM1, LM2, LM3
        self.RLM1, self.RLM2 = RLM1, RLM2
        self.RRLM2, self.RRLM3 = RRLM2, RRLM3
        self.RM1, self.RM2, self.RM3, self.RRM2, self.RRM3 = RM1, RM2, RM3, RRM2, RRM3
        self.SID1 = SID1
        self.SID3 = SID3
        self.sp2, self.sp4 = sp2, sp4
        self.gl_mu1, self.gl_mu2, self.gl_mu3, self.gl_mu4 = gl_mu1, gl_mu2, gl_mu3, gl_mu4
        self.gl_logvar1, self.gl_logvar2, self.gl_logvar3, self.gl_logvar4 = gl_logvar1, gl_logvar2, gl_logvar3, gl_logvar4

    def generate(self, M1, M2, S2, sampling):
        # M1, _, A1, S1, SID1 = batch_data
        # print("M1", M1)
        # print("M2", M2)
        # (B, seq_len, pose_dim) -> (B, pose_dim, seq_len)
        # OM1 = M1.clone()
        # OM2 = self.swap(OM1)
        M1 = M1.clone()
        M2 = M2.clone()
        # M1[..., 1:3] *= 0
        # M2[..., 1:3] *= 0
        M1 = M1.permute(0, 2, 1).to(self.opt.device).float().detach()
        M2 = M2.permute(0, 2, 1).to(self.opt.device).float().detach()

        if self.opt.use_style:
            S2 = S2.to(self.opt.device).float().detach()
        else:
            S2 = None
        LM1, _, _ = self.ae_encoder(M1[:, :-4])
        LM2, _, _ = self.ae_encoder(M2[:, :-4])

        # print(M1[:, :-4].shape, M2[:, :-4].shape, S1.shape, S2.shape)
        sp1 = self.encoder.extract_content_feature(LM1, None)
        gl_mu2, gl_logvar2 = self.encoder.extract_style_feature(LM2, S2)

        # z_sp = self.reparametrize(sp_mu1, sp_logvar1)
        z_sp = sp1
        if sampling:
            # Sample from normal distribution, novel style generation
            z_gl = self.reparametrize(self.zeros_like(gl_mu2), self.zeros_like(gl_logvar2))
        else:
            # Sample from M2 distribution, motion style transfer
            z_gl = self.reparametrize(gl_mu2, gl_logvar2)

        TLM = self.decoder(z_sp, z_gl, None, S2)
        TM = self.ae_decoder(TLM)

        return TM.permute(0, 2, 1)


    def backward(self):
        self.loss_rec_lm1 = self.l1_criterion(self.LM1, self.RLM1)
        self.loss_rec_lm2 = self.l1_criterion(self.LM2, self.RLM2)
        self.loss_rec_m1 = self.l1_criterion(self.M1, self.RM1)
        self.loss_rec_m2 = self.l1_criterion(self.M2, self.RM2)

        self.loss_rec_rlm2 = self.l1_criterion(self.LM2, self.RRLM2)
        self.loss_rec_rlm3 = self.l1_criterion(self.LM3, self.RRLM3)
        self.loss_rec_rm2 = self.l1_criterion(self.M2, self.RRM2)
        self.loss_rec_rm3 = self.l1_criterion(self.M3, self.RRM3)

        self.loss_rec_lat = self.l1_criterion(self.sp2, self.sp4)

        # print(self.loss_rec_m1, self.loss_rec_m2, self.loss_rec_rm2, self.loss_rec_rm3)

        self.loss_kld_gl_m1 = self.kl_criterion_unit(self.gl_mu1, self.gl_logvar1)
        self.loss_kld_gl_m2 = self.kl_criterion_unit(self.gl_mu2, self.gl_logvar2)
        self.loss_kld_gl_m4 = self.kl_criterion_unit(self.gl_mu4, self.gl_logvar4)
        self.loss_kld_gl_m12 = self.kl_criterion(self.gl_mu1, self.gl_logvar1, self.gl_mu2, self.gl_logvar2)
        self.loss_kld_gl_m34 = self.kl_criterion(self.gl_mu3, self.gl_logvar3, self.gl_mu4, self.gl_logvar4)

        self.loss = (self.loss_rec_lm1 + self.loss_rec_lm2 + self.loss_rec_m1 + self.loss_rec_m2) * self.opt.lambda_rec + \
                    (self.loss_rec_rlm2 + self.loss_rec_rlm3 + self.loss_rec_rm2 + self.loss_rec_rm3) * self.opt.lambda_rec_c + \
                    (self.loss_kld_gl_m1 + self.loss_kld_gl_m2 + self.loss_kld_gl_m4) * self.opt.lambda_kld_gl + \
                    self.loss_kld_gl_m12 * self.opt.lambda_kld_gl12


        loss_logs = OrderedDict({})
        loss_logs["loss"] = self.loss.item()
        loss_logs["loss_rec_lm1"] = self.loss_rec_lm1.item()
        loss_logs["loss_rec_lm2"] = self.loss_rec_lm2.item()
        loss_logs["loss_rec_m1"] = self.loss_rec_m1.item()
        loss_logs["loss_rec_m2"] = self.loss_rec_m2.item()

        loss_logs["loss_rec_rlm2"] = self.loss_rec_rlm2.item()
        loss_logs["loss_rec_rlm3"] = self.loss_rec_rlm3.item()
        loss_logs["loss_rec_rm2"] = self.loss_rec_rm2.item()
        loss_logs["loss_rec_rm3"] = self.loss_rec_rm3.item()
        loss_logs["loss_rec_lat"] = self.loss_rec_lat.item()

        loss_logs["loss_kld_gl_m1"] = self.loss_kld_gl_m1.item()
        loss_logs["loss_kld_gl_m2"] = self.loss_kld_gl_m2.item()
        loss_logs["loss_kld_gl_m4"] = self.loss_kld_gl_m4.item()
        loss_logs["loss_kld_gl_m12"] = self.loss_kld_gl_m12.item()
        loss_logs["loss_kld_gl_m34"] = self.loss_kld_gl_m34.item()

        return loss_logs

    def update(self):
        """Perform one optimization step on the latent VAE encoder and decoder.

        Uses the tensors cached by the latest ``forward`` call: clears the
        gradients, evaluates the losses, back-propagates ``self.loss``, clips the
        gradient norm of both networks and steps both optimizers.

        Returns:
            The loss dictionary produced by ``backward``.
        """
        optimizers = [self.opt_encoder, self.opt_decoder]
        trained_networks = [self.encoder, self.decoder]

        self.zero_grad(optimizers)
        loss_logs = self.backward()
        self.loss.backward()
        self.clip_norm(trained_networks)
        self.step(optimizers)
        return loss_logs

    def save(self, file_name, ep, total_it):
        state = {
            "encoder": self.encoder.state_dict(),
            "decoder": self.decoder.state_dict(),
            "ae_encoder": self.ae_encoder.state_dict(),
            "ae_decoder": self.ae_decoder.state_dict(),

            "opt_encoder": self.opt_encoder.state_dict(),
            "opt_decoder": self.opt_decoder.state_dict(),
            "opt_ae_decoder": self.opt_ae_decoder.state_dict(),

            "ep": ep,
            "total_it": total_it,
        }

        torch.save(state, file_name)

    def resume(self, model_dir):
        # print(model_dir)
        checkpoint = torch.load(model_dir, map_location=self.opt.device)
        self.encoder.load_state_dict(checkpoint["encoder"])
        self.decoder.load_state_dict(checkpoint["decoder"])
        self.ae_encoder.load_state_dict(checkpoint["ae_encoder"])
        self.ae_decoder.load_state_dict(checkpoint["ae_decoder"])

        if self.opt.is_train:
            self.opt_encoder.load_state_dict(checkpoint["opt_encoder"])
            self.opt_decoder.load_state_dict(checkpoint["opt_decoder"])
            self.opt_ae_decoder.load_state_dict(checkpoint["opt_ae_decoder"])
        print("Loading the model from epoch %04d"%checkpoint["ep"])
        return checkpoint["ep"], checkpoint["total_it"]

    def train(self, train_dataloader, val_dataloader, plot_eval):
        """Train the latent VAE, validating and checkpointing once per epoch.

        Creates the optimizers, optionally resumes from ``latest.tar`` in
        ``opt.model_dir``, then loops until ``opt.max_epoch``: one pass over
        ``train_dataloader`` (with periodic logging and ``latest.tar`` saves),
        one pass over ``val_dataloader`` under ``torch.no_grad()``, a
        ``best.tar`` save whenever the validation ``"loss"`` improves, and a
        call to ``plot_eval`` every ``opt.eval_every_e`` epochs.

        Args:
            train_dataloader: Iterable of batches accepted by ``forward``; must
                support ``len``.
            val_dataloader: Iterable of validation batches; must support ``len``
                and be non-empty (with an empty loader ``val_loss`` stays
                ``None`` and the averaging loop below fails).
            plot_eval: Callable invoked as ``plot_eval(data, save_dir, styles)``.
        """
        net_list = [self.encoder, self.decoder, self.ae_encoder, self.ae_decoder]
        self.to(net_list, self.opt.device)

        self.opt_encoder = optim.Adam(self.encoder.parameters(), lr=self.opt.lr)
        self.opt_decoder = optim.Adam(self.decoder.parameters(), lr=self.opt.lr)
        # NOTE(review): despite its name this optimizer is built over
        # self.decoder.parameters(), not self.ae_decoder.parameters(), and
        # update() never steps it - it only ends up in the checkpoints.  Looks
        # like a copy-paste slip; changing the parameter set would alter the
        # saved optimizer state, so confirm before fixing.
        self.opt_ae_decoder = optim.Adam(self.decoder.parameters(), lr=self.opt.lr*0.1)

        epoch = 0
        it = 0
        if self.opt.is_continue:
            model_dir = pjoin(self.opt.model_dir, "latest.tar")
            epoch, it = self.resume(model_dir)
            print("Loading model from Epoch %d" % (epoch))

        start_time = time.time()
        total_iters = self.opt.max_epoch * len(train_dataloader)
        print("Iters Per Epoch, Training: %04d, Validation: %03d" % (len(train_dataloader), len(val_dataloader)))
        min_val_loss = np.inf
        # val_loss = 0
        # Running sums of the training losses since the last log line.
        logs = OrderedDict()

        while epoch < self.opt.max_epoch:
            # All four networks, including the autoencoder pair, are switched
            # to train mode here and to eval mode for validation.
            self.net_train(net_list)
            for i, batch_data in enumerate(train_dataloader):
                self.forward(batch_data)
                loss_dict = self.update()

                for k, v in loss_dict.items():
                    if k not in logs:
                        logs[k] = v
                    else:
                        logs[k] += v
                it += 1
                if it % self.opt.log_every == 0:
                    mean_loss = OrderedDict()
                    # self.logger.scalar_summary("val_loss", val_loss, it)

                    # Report the mean over the last log_every iterations, then
                    # start a new accumulation window.
                    for tag, value in logs.items():
                        self.logger.scalar_summary(tag, value / self.opt.log_every, it)
                        mean_loss[tag] = value / self.opt.log_every
                    logs = OrderedDict()
                    print_current_loss(start_time, it, total_iters, mean_loss, epoch, i)

                if it % self.opt.save_latest == 0:
                    self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)

            # Saved before the counter is incremented, so the checkpoint stores
            # the index of the epoch that has just finished.
            self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)
            epoch += 1
            # if epoch % self.opt.save_every_e == 0:
            #     self.save(pjoin(self.opt.model_dir, "E%04d.tar"%(epoch)), epoch, total_it=it)

            print("Validation time:")
            val_loss = None
            with torch.no_grad():
                self.net_eval(net_list)
                for i, batch_data in enumerate(val_dataloader):
                    self.forward(batch_data)
                    # backward() only evaluates the losses; no gradient or
                    # optimizer step is involved here.
                    loss_dict = self.backward()
                    if val_loss is None:
                        # The first batch's dict doubles as the accumulator.
                        val_loss = loss_dict
                    else:
                        for k, v in loss_dict.items():
                            val_loss[k] += v

            print_str = "Validation Loss:"
            # Turn the per-key sums into per-batch means (in place).
            for k, v in val_loss.items():
                val_loss[k] /= len(val_dataloader)
                # self.logger.scalar_summary(k, val_loss[k], epoch)
                print_str += ' %s: %.4f ' % (k, val_loss[k])
            self.logger.scalar_summary("val_loss", val_loss["loss"], epoch)
            print(print_str)

            if val_loss["loss"] < min_val_loss:
                min_val_loss = val_loss["loss"]
                # Recorded but not read anywhere in the visible code.
                min_val_epoch = epoch
                self.save(pjoin(self.opt.model_dir, "best.tar"), epoch, it)
                print("Best Validation Model So Far!~")

            if epoch % self.opt.eval_every_e == 0:
                # The cached tensors come from the most recent forward() call,
                # i.e. the last validation batch.  B is not used below.
                B = self.M1.size(0)
                # Samples 0, 2 and 4 (where present) of: M2, RM2, M3 and RM3.
                data = torch.cat([self.M2[:6:2], self.RM2[:6:2], self.M3[:6:2], self.RM3[:6:2]],
                                 dim=0)
                styles = torch.cat([self.SID1[:6:2], self.SID1[:6:2], self.SID3[:6:2], self.SID3[:6:2]],
                                   dim=0).detach().cpu().numpy()
                # Undo the dim-1/dim-2 swap applied in forward() before plotting.
                data = data.permute(0, 2, 1).detach().cpu().numpy()
                save_dir = pjoin(self.opt.eval_dir, "E%04d" % (epoch))
                os.makedirs(save_dir, exist_ok=True)
                plot_eval(data, save_dir, styles)
